[BUG] Iteration using Range without providing a start is slower than providing a start #3931

Open · sstadick opened this issue Jan 7, 2025 · 5 comments
Labels: bug (Something isn't working), mojo-repo (Tag all issues with this label)

@sstadick

sstadick commented Jan 7, 2025

Bug description

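Iterating over a List or Span is slower with range(len(list)) than with either range(0, len(list)) or direct iteration (for value in list). In the benchmark below, range(len(data)) averages about 0.044 ms per run versus about 0.031 ms for the other two forms (roughly 37% slower), and adding assume(i < len(data)) makes no meaningful difference to either range form.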

Below is a minimal reproducible example, although I ran into this while benchmarking other code.

import benchmark
from sys.intrinsics import assume
from benchmark import keep, Unit
from memory import Span


fn main() raises:
    var input = List[UInt8]()
    for i in range(100000):
        input.append(i)
    var data = Span(input)

    @parameter
    fn range_without_start() raises:
        var sum = 0
        for i in range(len(data)):
            sum += int(data[i])
        keep(sum)

    @parameter
    fn range_without_start_assume() raises:
        var sum = 0
        for i in range(len(data)):
            assume(i < len(data))
            sum += int(data[i])
        keep(sum)

    @parameter
    fn range_with_start() raises:
        var sum = 0
        for i in range(0, len(data)):
            sum += int(data[i])
        keep(sum)

    @parameter
    fn range_with_start_assume() raises:
        var sum = 0
        for i in range(0, len(data)):
            assume(i < len(data))
            sum += int(data[i])
        keep(sum)

    @parameter
    fn range_iter() raises:
        var sum = 0
        for value in data:
            sum += int(value[])
        keep(sum)

    print("Without start")
    var report = benchmark.run[range_without_start]()
    report.print(Unit.ms)
    print("Without start with assume")
    report = benchmark.run[range_without_start_assume]()
    report.print(Unit.ms)
    print("With start")
    report = benchmark.run[range_with_start]()
    report.print(Unit.ms)
    print("With start with assume")
    report = benchmark.run[range_with_start_assume]()
    report.print(Unit.ms)
    print("With direct iter")
    report = benchmark.run[range_iter]()
    report.print(Unit.ms)

Output:

Without start
--------------------------------------------------------------------------------
Benchmark Report (ms)
--------------------------------------------------------------------------------
Mean: 0.043647766695576753
Total: 2415.642
Iters: 55344
Warmup Total: 0.043
Fastest Mean: 0.043647766695576753
Slowest Mean: 0.043647766695576753

Without start with assume
--------------------------------------------------------------------------------
Benchmark Report (ms)
--------------------------------------------------------------------------------
Mean: 0.04344013745204633
Total: 2389.251
Iters: 55001
Warmup Total: 0.043
Fastest Mean: 0.04344013745204633
Slowest Mean: 0.04344013745204633

With start
--------------------------------------------------------------------------------
Benchmark Report (ms)
--------------------------------------------------------------------------------
Mean: 0.03185076622576459
Total: 2429.672
Iters: 76283
Warmup Total: 0.031
Fastest Mean: 0.03185076622576458
Slowest Mean: 0.03185076622576458

With start with assume
--------------------------------------------------------------------------------
Benchmark Report (ms)
--------------------------------------------------------------------------------
Mean: 0.031492011834319524
Total: 2235.303
Iters: 70980
Warmup Total: 0.031
Fastest Mean: 0.03149201183431953
Slowest Mean: 0.03149201183431953

With direct iter
--------------------------------------------------------------------------------
Benchmark Report (ms)
--------------------------------------------------------------------------------
Mean: 0.0315751325175054
Total: 2412.498
Iters: 76405
Warmup Total: 0.031
Fastest Mean: 0.0315751325175054
Slowest Mean: 0.0315751325175054

Steps to reproduce

Run the minimal reproducible example above.

System information

- OS: macOS
- Mojo version (output of `mojo -v`):

❯ magic run mojo -v
mojo 24.6.0 (4487cd6e)

- Magic CLI version (output of `magic -V`):

❯ magic -V
magic 0.5.1 - (based on pixi 0.37.0)
sstadick added the bug and mojo-repo labels on Jan 7, 2025.
@sstadick
Author

sstadick commented Jan 7, 2025

Relatedly, and anecdotally (I'm still working on a good reproducible example), Mojo does not seem to optimize loops that access an array by index as well as Rust does. Specifically, it seems unable to skip the bounds checks.
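For readers less familiar with the Rust side of this comparison, here is a minimal Rust sketch (not code from this thread) of the two loop shapes involved. In the indexed loop over `0..v.len()`, each `v[i]` is bounds-checked in principle, but the optimizer can often prove `i < v.len()` and drop the check in release builds. The iterator loop has no index, so there is nothing to check. Whether a given check is actually eliminated depends on the compiler version and surrounding code, so comparing the generated assembly (for example on Compiler Explorer) is more reliable than assuming. The function names `sum_indexed` and `sum_iter` are hypothetical.

```rust
/// Indexed loop: `v[i]` is bounds-checked in principle, but because the range
/// is `0..v.len()` the optimizer can often prove `i < v.len()` and drop it.
fn sum_indexed(v: &[u8]) -> u64 {
    let mut sum: u64 = 0;
    for i in 0..v.len() {
        sum += u64::from(v[i]);
    }
    sum
}

/// Iterator loop: there is no index, so there is no bounds check to eliminate.
fn sum_iter(v: &[u8]) -> u64 {
    v.iter().map(|&x| u64::from(x)).sum()
}

fn main() {
    // Roughly mirrors the Mojo benchmark input: 100_000 values, truncated to u8 by `as u8`.
    let input: Vec<u8> = (0..100_000u32).map(|i| i as u8).collect();
    println!("{} {}", sum_indexed(&input), sum_iter(&input));
}
```

Both functions compute the same sum; only the loop shape, and therefore what the optimizer has to prove, differs.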

@sstadick
Author

sstadick commented Jan 7, 2025

Here is an example of (I hope) equivalent Rust and Mojo code where the Rust loop is faster:
https://github.com/sstadick/rust-vs-mojo-loop
I can open a separate issue for this if it isn't already known and is actually deemed an issue rather than just poor benchmarking on my part.

import sys


fn main() raises:
    var times = sys.argv()[1].__int__()

    var array = List[UInt64]()
    for i in range(0, times):
        array.append(i)

    var sum: UInt64 = 0
    for _ in range(0, times):
        for i in range(0, times):
            sum += array[i]
    print(sum)
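The linked repository has the actual Rust code. Purely as an illustration of the same shape (a hypothetical sketch, not the repository's code), a Rust version of the Mojo program above might look like this. `nested_sum` is a made-up helper name. `wrapping_add` is chosen to mirror fixed-width UInt64 wraparound instead of panicking on overflow in debug builds. The default of 1000 when no argument is passed is also an assumption, since the Mojo version requires the argument.

```rust
use std::env;

/// Same shape as the Mojo loop: fill `array` with 0..times, then sum it `times` times.
fn nested_sum(times: usize) -> u64 {
    let array: Vec<u64> = (0..times as u64).collect();
    let mut sum: u64 = 0;
    for _ in 0..times {
        for i in 0..times {
            // Indexed access, as in the Mojo version (bounds-checked unless the optimizer elides it).
            sum = sum.wrapping_add(array[i]);
        }
    }
    sum
}

fn main() {
    // Like `sys.argv()[1]` in the Mojo code: the first CLI argument is the size.
    let times: usize = env::args()
        .nth(1)
        .and_then(|s| s.parse().ok())
        .unwrap_or(1000); // hypothetical default when no argument is given
    println!("{}", nested_sum(times));
}
```

When timing, build with optimizations (`cargo build --release` or `rustc -O`). Debug builds are unoptimized and would not reflect whether bounds checks get removed.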

@sstadick
Author

sstadick commented Jan 8, 2025

Assembly for the two versions of range (attached as range_without_start.txt and range_with_start.txt).

I'm not an assembly expert, but the range-with-start loop seems to have fewer instructions and possibly less range validation.

@sstadick
Author

sstadick commented Jan 8, 2025

I'm reasonably sure this boils down to _SequentialRange (returned by range(start, end)) being more compiler-friendly than _ZeroStartingRange (returned by range(end)).
It looks like _ZeroStartingRange tries to take advantage of being able to do one less operation by skipping the max call in __len__(), but for unclear reasons this makes it harder for the compiler to optimize the loop.

The fix, IMO, would be to have fn range[type: Intable](end: type) return a _SequentialRange instead, although I could also see this being viewed as a compiler bug.

If swapping to _SequentialRange is acceptable, I'd be happy to make a PR with the updates, which look like they only touch builtin.range and utils.loop.
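A Rust analogue of the two iterator shapes under discussion makes the difference more concrete. This is not Mojo stdlib code. `CountdownRange` assumes (this is not stated in the thread) that the 24.6 `_ZeroStartingRange` clamps `end` with `max` once at construction and then counts a remaining-length field down. Its length therefore needs no `max`, and each index is derived as `end - remaining`. `SequentialRange` stores `start`/`end` like the proposed fix and computes its length with `max(0, end - start)`. Both yield the same sequence. In the second, the yielded value is the counter itself, which is plausibly easier for an optimizer to relate to `len(data)`. Both type names are hypothetical.

```rust
/// Analogue of a zero-starting range that counts down a remaining-length field.
/// Its length is just `remaining`, so no `max` is needed per call.
struct CountdownRange {
    remaining: i64,
    end: i64,
}

impl CountdownRange {
    fn new(end: i64) -> Self {
        let end = end.max(0); // clamp once, up front
        CountdownRange { remaining: end, end }
    }
}

impl Iterator for CountdownRange {
    type Item = i64;
    fn next(&mut self) -> Option<i64> {
        if self.remaining > 0 {
            let idx = self.end - self.remaining; // index derived from the counter
            self.remaining -= 1;
            Some(idx)
        } else {
            None
        }
    }
}

/// Analogue of the proposed fix: store start/end and yield `start` directly.
struct SequentialRange {
    start: i64,
    end: i64,
}

impl SequentialRange {
    fn len(&self) -> i64 {
        (self.end - self.start).max(0) // the `max` the countdown version avoids
    }
}

impl Iterator for SequentialRange {
    type Item = i64;
    fn next(&mut self) -> Option<i64> {
        if self.len() > 0 {
            let idx = self.start; // the yielded value is the counter itself
            self.start += 1;
            Some(idx)
        } else {
            None
        }
    }
}

fn main() {
    let a: Vec<i64> = CountdownRange::new(5).collect();
    let b: Vec<i64> = SequentialRange { start: 0, end: 5 }.collect();
    println!("{:?} {:?}", a, b);
}
```

Both shapes are semantically equivalent, including for a negative `end`. The thread's benchmarks suggest the difference is purely in how well the compiler optimizes each one.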

@sstadick
Author

sstadick commented Jan 8, 2025

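Demo of the proposed fix working, with performance matching range(0, len(data)) and direct iteration. The custom _ZeroStartingRange below stores start and end and computes its length with max(0, end - start):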

import benchmark
from sys.intrinsics import assume
from benchmark import keep, Unit
from memory import Span

from builtin._stubs import _IntIterable
from builtin.range import _StridedRange


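# Local stand-in for the stdlib _ZeroStartingRange, modeled on _SequentialRange:
# it stores start/end, counts start upward, and computes __len__ with max().
@register_passable("trivial")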
struct _ZeroStartingRange(Sized, ReversibleRange, _IntIterable):
    var start: Int
    var end: Int

    @always_inline
    @implicit
    fn __init__(out self, end: Int):
        self.start = 0
        self.end = end

    @always_inline
    fn __iter__(self) -> Self:
        return self

    @always_inline
    fn __next__(mut self) -> Int:
        var start = self.start
        self.start += 1
        return start

    @always_inline
    fn __has_next__(self) -> Bool:
        return self.__len__() > 0

    @always_inline
    fn __len__(self) -> Int:
        return max(0, self.end - self.start)

    @always_inline
    fn __getitem__(self, idx: Int) -> Int:
        debug_assert(idx < self.__len__(), "index out of range")
        return self.start + index(idx)

    @always_inline
    fn __reversed__(self) -> _StridedRange:
        return range(self.end - 1, self.start - 1, -1)


fn main() raises:
    var input = List[UInt8]()
    for i in range(100000):
        input.append(i)
    var data = Span(input)
    var s1 = 0
    for i in _ZeroStartingRange(len(data)):
        s1 += int(data[i])
    var s2 = 0
    for i in range(0, len(data)):
        s2 += int(data[i])
    if s1 != s2:
        raise "Invalid custom iterator"

    @parameter
    fn range_custom_iter_implicit_zero() raises:
        var sum = 0
        for i in _ZeroStartingRange(len(data)):
            sum += int(data[i])
        keep(sum)

    @parameter
    fn range_custom_iter_implicit_zero_fixed() raises:
        var sum = 0
        for i in _ZeroStartingRange(len(data)):
            sum += int(data[i])
        keep(sum)

    @parameter
    fn range_without_start() raises:
        var sum = 0
        for i in range(len(data)):
            sum += int(data[i])
        keep(sum)

    @parameter
    fn range_without_start_assume() raises:
        var sum = 0
        for i in range(len(data)):
            assume(i < len(data))
            sum += int(data[i])
        keep(sum)

    @parameter
    fn range_with_start() raises:
        var sum = 0
        for i in range(0, len(data)):
            sum += int(data[i])
        keep(sum)

    @parameter
    fn range_with_start_assume() raises:
        var sum = 0
        for i in range(0, len(data)):
            assume(i < len(data))
            sum += int(data[i])
        keep(sum)

    @parameter
    fn range_iter() raises:
        var sum = 0
        for value in data:
            sum += int(value[])
        keep(sum)

    print("Custom Iter")
    var report = benchmark.run[range_custom_iter_implicit_zero]()
    report.print(Unit.ms)
    print("Without start")
    report = benchmark.run[range_without_start]()
    report.print(Unit.ms)
    print("Without start with assume")
    report = benchmark.run[range_without_start_assume]()
    report.print(Unit.ms)
    print("With start")
    report = benchmark.run[range_with_start]()
    report.print(Unit.ms)
    print("With start with assume")
    report = benchmark.run[range_with_start_assume]()
    report.print(Unit.ms)
    print("With direct iter")
    report = benchmark.run[range_iter]()
    report.print(Unit.ms)

Output:

Custom Iter
--------------------------------------------------------------------------------
Benchmark Report (ms)
--------------------------------------------------------------------------------
Mean: 0.03149534626912762
Total: 2395.788
Iters: 76068
Warmup Total: 0.031
Fastest Mean: 0.03149534626912762
Slowest Mean: 0.03149534626912762

Without start
--------------------------------------------------------------------------------
Benchmark Report (ms)
--------------------------------------------------------------------------------
Mean: 0.04372536524431087
Total: 2415.258
Iters: 55237
Warmup Total: 0.043
Fastest Mean: 0.04372536524431088
Slowest Mean: 0.04372536524431088

Without start with assume
--------------------------------------------------------------------------------
Benchmark Report (ms)
--------------------------------------------------------------------------------
Mean: 0.044353636801541425
Total: 1841.563
Iters: 41520
Warmup Total: 0.043
Fastest Mean: 0.044353636801541425
Slowest Mean: 0.044353636801541425

With start
--------------------------------------------------------------------------------
Benchmark Report (ms)
--------------------------------------------------------------------------------
Mean: 0.031408006803166084
Total: 2400.671
Iters: 76435
Warmup Total: 0.031
Fastest Mean: 0.03140800680316609
Slowest Mean: 0.03140800680316609

With start with assume
--------------------------------------------------------------------------------
Benchmark Report (ms)
--------------------------------------------------------------------------------
Mean: 0.03140519595927915
Total: 2403.157
Iters: 76521
Warmup Total: 0.031
Fastest Mean: 0.03140519595927915
Slowest Mean: 0.03140519595927915

With direct iter
--------------------------------------------------------------------------------
Benchmark Report (ms)
--------------------------------------------------------------------------------
Mean: 0.03144227548670714
Total: 2403.196
Iters: 76432
Warmup Total: 0.031
Fastest Mean: 0.03144227548670714
Slowest Mean: 0.03144227548670714
