Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Several improvements #95

Merged
merged 15 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ jobs:
- uses: ./.github/workflows/actions/run_nox_session
with:
nox_session: ${{ matrix.session }}
- if: ${{ matrix.session == 'coverage'}}
- if: ${{ startsWith(matrix.session, 'coverage') }}
uses: codecov/codecov-action@v3
with:
files: ./codecov.xml
Expand Down
14 changes: 14 additions & 0 deletions examples/stqdm_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from time import sleep

import streamlit as st

from stqdm import stqdm

bar_format = st.selectbox(
"Select bar_format", ["", "{bar}", "{l_bar}{bar}{r_bar}", "{desc} {percentage:.0f}%", "blabla", None]
)

empty = st.container()

for item in stqdm(range(10), bar_format=bar_format, st_container=empty):
sleep(0.5)
10 changes: 10 additions & 0 deletions examples/stqdm_leave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from time import sleep

import streamlit as st

from stqdm import stqdm

leave = st.checkbox("Should leave progress bar")

for _ in stqdm(range(10), leave=leave):
sleep(0.1)
24 changes: 14 additions & 10 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,26 +110,30 @@ def install_deps(session: nox.Session, constraint_groups: List[str], dependencie
with_python_versions(["3.8", "3.9"], "~=0.66.0", "~=4.50.0")
+ with_python_versions(["3.8", "3.9"], "~=0.66.0", "~=4.50.0")
+ with_python_versions(["3.8", "3.9"], "~=1.4.0", "~=4.50.0")
+ with_python_versions(["3.8", "3.9"], "~=1.4.0", "~=4.63.0")
+ with_python_versions(["3.8", "3.9", "3.10"], "~=1.8.0", "~=4.63.0")
+ with_python_versions(["3.8", "3.9", "3.10"], "~=1.12.0", "~=4.63.0")
+ with_python_versions(["3.8", "3.9"], "~=1.4.0", "~=4.66.1")
+ with_python_versions(["3.8", "3.9", "3.10"], "~=1.8.0", "~=4.66.1")
+ with_python_versions(["3.8", "3.9", "3.10"], "~=1.12.0", "~=4.66.1")
+ with_python_versions(["3.9", "3.10"], "~=1.12.0", LATEST)
+ with_python_versions(["3.10"], "~=1.22.0", LATEST)
+ with_python_versions(["3.9", "3.10"], LATEST, LATEST)
+ with_python_versions(["3.11"], "~=1.22.0", LATEST)
+ with_python_versions(["3.9", "3.10", "3.11"], LATEST, LATEST)
)


@nox.session
@nox.parametrize(["python", "streamlit_version", "tqdm_version"], PYTHON_ST_TQDM_VERSIONS)
def tests(session: nox.Session, streamlit_version: str, tqdm_version: str) -> None:
dependencies_to_install = build_dependencies_to_install_list(streamlit_version, tqdm_version, [".", "pytest"])
dependencies_to_install = build_dependencies_to_install_list(streamlit_version, tqdm_version, [".", "pytest", "freezegun"])
install_deps(session, constraint_groups=["dev"], dependencies_to_install=dependencies_to_install)
session.run("pytest")


@nox_poetry.session(python=None)
def coverage(session: nox_poetry.Session) -> None:
session.install("pytest", "pytest-cov", ".")
@nox.session(python=None)
@nox.parametrize(["python", "streamlit_version", "tqdm_version"], [PYTHON_ST_TQDM_VERSIONS[0]] + [PYTHON_ST_TQDM_VERSIONS[-1]])
def coverage(session: nox.Session, streamlit_version: str, tqdm_version: str) -> None:
dependencies_to_install = build_dependencies_to_install_list(
streamlit_version, tqdm_version, [".", "pytest", "pytest-cov", "freezegun"]
)
install_deps(session, constraint_groups=["dev"], dependencies_to_install=dependencies_to_install)
session.run("pytest", "--cov-fail-under=15", "--cov=stqdm", "--cov-report=xml:codecov.xml")


Expand All @@ -147,5 +151,5 @@ def black(session: nox_poetry.Session) -> None:

@nox_poetry.session(python=None)
def lint(session: nox_poetry.Session) -> None:
session.install("pylint", "nox", "nox_poetry", "tqdm", "streamlit")
session.install("pylint", "nox", "nox_poetry", "tqdm", "streamlit", "pytest", "freezegun")
session.run("pylint", "stqdm", "examples", "tests", "noxfile.py")
1,365 changes: 682 additions & 683 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ isort = "^5.12.0"
pylint = "^2.17.5"
nox = "^2023.4.22"
nox-poetry = "^1.0.3"
freezegun = "^1.4.0"

[tool.black]
line-length = 127
Expand Down Expand Up @@ -68,6 +69,7 @@ disable = [
max-line-length=140
docstring-min-length=15
max-args = 6
no-docstring-rgx = "^_|^test_|^Test[A-Z]" # no docstrings for tests

[tool.pylint.miscellaneous]
notes = ["FIXME"]
Expand Down
81 changes: 64 additions & 17 deletions stqdm/stqdm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
import re
from typing import TYPE_CHECKING, Optional

import streamlit as st
from packaging import version
from tqdm.auto import tqdm

# pragma: no cover
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator

IS_TEXT_INSIDE_PROGRESS_AVAILABLE = version.parse(st.__version__) >= version.parse("1.18.0")
BAR_FORMAT_REGEX = re.compile(r"\{bar(?:[:!][a-zA-Z0-9]+){,2}}")


class stqdm(tqdm): # pylint: disable=invalid-name,inconsistent-mro
def __init__(
Expand Down Expand Up @@ -30,18 +41,32 @@ def __init__(
nrows=None,
colour=None,
gui=False,
st_container=None,
backend=False,
frontend=True,
st_container: Optional["DeltaGenerator"] = None,
backend: bool = False,
frontend: bool = True,
**kwargs,
): # pylint: disable=too-many-arguments,too-many-locals
if st_container is None:
st_container = st
self._backend = backend
self._frontend = frontend
self.st_container = st_container
self._st_progress_bar = None
self._st_text = None
self.st_container: "DeltaGenerator" = st_container
self._st_progress_bar: Optional["DeltaGenerator"] = None
self._st_text: Optional["DeltaGenerator"] = None

if ncols is None and not bar_format:
ncols = 0 # rely on standard tqdm way to not display progress bar
if bar_format:
original_bar_format = bar_format
bar_format = self.remove_bar_from_format(original_bar_format)
should_display_progress_bar = bar_format != original_bar_format
should_display_text = bool(bar_format.strip())
else:
should_display_progress_bar = True
should_display_text = True
self.should_display_progress_bar: bool = should_display_progress_bar
self.should_display_text: bool = should_display_text

super().__init__(
iterable=iterable,
desc=desc,
Expand Down Expand Up @@ -72,39 +97,61 @@ def __init__(
)

@property
def st_progress_bar(self) -> st.progress:
def st_progress_bar(self) -> "DeltaGenerator":
if self._st_progress_bar is None:
self._st_progress_bar = self.st_container.empty()
return self._st_progress_bar

@property
def st_text(self) -> st.empty:
def st_text(self) -> "DeltaGenerator":
if self._st_text is None:
self._st_text = self.st_container.empty()
return self._st_text

def st_display(self, n, total, **kwargs): # pylint: disable=invalid-name
if total is not None and total > 0:
self.st_text.write(self.format_meter(n, total, **{**kwargs, "ncols": 0}))
self.st_progress_bar.progress(n / total)
if total is None:
self.st_text.write(self.format_meter(n, total, **{**kwargs, "ncols": 0}))
def st_display(self, n: int, total: Optional[int], **kwargs) -> None: # pylint: disable=invalid-name
"""Display the progress bar and text in streamlit"""
if self.should_display_text:
meter_text = self.format_meter(n, total, **kwargs)
else:
meter_text = None

can_display_text = bool(meter_text)
can_display_progress_bar = total is not None and total > 0

def display(self, msg=None, pos=None):
if can_display_progress_bar and self.should_display_progress_bar:
if not can_display_text:
self.st_progress_bar.progress(n / total)
elif IS_TEXT_INSIDE_PROGRESS_AVAILABLE:
self.st_progress_bar.progress(n / total, text=meter_text)
else:
self.st_text.write(meter_text)
self.st_progress_bar.progress(n / total)
else:
if can_display_text:
self.st_text.write(meter_text)

def display(self, msg=None, pos=None) -> bool:
if self._backend:
super().display(msg, pos)
if self._frontend:
self.st_display(**self.format_dict)
return True

def st_clear(self):
def st_clear(self) -> None:
leave = self.pos == 0 if self.leave is None else self.leave
if leave:
return
if self._st_text is not None:
self._st_text.empty()
self._st_text = None
if self._st_progress_bar is not None:
self._st_progress_bar.empty()
self._st_progress_bar = None

def close(self):
def close(self) -> None:
super().close()
self.st_clear()

@staticmethod
def remove_bar_from_format(bar_format: str) -> str:
return re.sub(BAR_FORMAT_REGEX, "", bar_format)
125 changes: 116 additions & 9 deletions tests/test_streamlit_frontend.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,147 @@
from unittest.mock import patch
from datetime import timedelta
from typing import Optional
from unittest.mock import MagicMock, patch

import pytest
from freezegun import freeze_time
from tqdm import tqdm

from stqdm.stqdm import stqdm
from stqdm.stqdm import IS_TEXT_INSIDE_PROGRESS_AVAILABLE, stqdm

TQDM_RUN_EVERY_ITERATION = {
"mininterval": 0,
"miniters": 0,
}
DESCRIPTION = "progress_bar_description"


def test_works_out_of_streamlit():
for _ in stqdm(range(2)):
pass


def assert_frontend_as_been_called_with(stqdmed_iterator: stqdm, text: Optional[str], progress: Optional[float]):
if text is None and progress is None:
raise ValueError("Nothing to assert.")
if progress is None:
stqdmed_iterator.st_text.write.assert_called_with(text)
stqdmed_iterator.st_progress_bar.progress.assert_not_called()
elif text is None:
stqdmed_iterator.st_text.write.assert_not_called()
stqdmed_iterator.st_progress_bar.progress.assert_called_with(progress)
elif IS_TEXT_INSIDE_PROGRESS_AVAILABLE:
stqdmed_iterator.st_progress_bar.progress.assert_called_with(progress, text=text)
stqdmed_iterator.st_text.write.assert_not_called()
else:
stqdmed_iterator.st_text.write.assert_called_with(text)
stqdmed_iterator.st_progress_bar.progress.assert_called_with(progress)


@patch("streamlit.empty")
def test_writes_tqdm_description(_):
stqdmed_iterator = stqdm(range(2), **TQDM_RUN_EVERY_ITERATION)
for i, _ in enumerate(stqdmed_iterator):
stqdmed_iterator.st_text.write.assert_called_with(tqdm.format_meter(**{**stqdmed_iterator.format_dict, "ncols": 0}))
stqdmed_iterator.st_progress_bar.progress.assert_called_with(i / len(stqdmed_iterator))
assert_frontend_as_been_called_with(
stqdmed_iterator,
text=tqdm.format_meter(**{**stqdmed_iterator.format_dict, "ncols": 0}),
progress=i / len(stqdmed_iterator),
)


@patch("streamlit.empty")
def test_writes_tqdm_description_when_no_length_but_total(_):
stqdmed_iterator = stqdm((i for i in range(2)), total=2, **TQDM_RUN_EVERY_ITERATION)
stqdmed_iterator = stqdm((i for i in range(2)), total=3, **TQDM_RUN_EVERY_ITERATION)
for i, _ in enumerate(stqdmed_iterator):
stqdmed_iterator.st_text.write.assert_called_with(tqdm.format_meter(**{**stqdmed_iterator.format_dict, "ncols": 0}))
stqdmed_iterator.st_progress_bar.progress.assert_called_with(i / len(stqdmed_iterator))
assert_frontend_as_been_called_with(
stqdmed_iterator,
text=tqdm.format_meter(**{**stqdmed_iterator.format_dict, "ncols": 0}),
progress=i / 3,
)


@patch("streamlit.empty")
def test_writes_tqdm_description_when_no_length_no_total(_):
stqdmed_iterator = stqdm((i for i in range(2)), **TQDM_RUN_EVERY_ITERATION)

for _ in stqdmed_iterator:
stqdmed_iterator.st_text.write.assert_called_with(tqdm.format_meter(**{**stqdmed_iterator.format_dict, "ncols": 0}))
stqdmed_iterator.st_progress_bar.progress.assert_not_called()
assert_frontend_as_been_called_with(
stqdmed_iterator,
text=tqdm.format_meter(**{**stqdmed_iterator.format_dict, "ncols": 0}),
progress=None,
)


@patch("streamlit.empty")
@pytest.mark.parametrize(
"bar_format,get_text",
[
(None, lambda i, total: tqdm.format_meter(n=i, total=total, elapsed=i, ncols=0, prefix=DESCRIPTION)),
("{bar}", lambda i, total: None),
(
"{bar}{desc}",
lambda i, total: tqdm.format_meter(n=i, total=total, elapsed=i, bar_format="{desc}", prefix=DESCRIPTION),
),
],
)
def test_bar_format(_, bar_format, get_text):
with freeze_time("2020-01-01") as frozen_time:
stqdmed_iterator = stqdm(range(2), bar_format=bar_format, **TQDM_RUN_EVERY_ITERATION, desc=DESCRIPTION)
for i, _ in enumerate(stqdmed_iterator):
frozen_time.tick(timedelta(seconds=1))
assert_frontend_as_been_called_with(
stqdmed_iterator,
text=get_text(i=i, total=2),
progress=i / len(stqdmed_iterator),
)


@patch("streamlit.empty")
def test_leave_false_keeps_stqdm(_):
# pylint: disable=protected-access
stqdmed_iterator = stqdm(range(2), leave=True)
mock_progress_bar = MagicMock()
mock_text = MagicMock()
stqdmed_iterator._st_progress_bar = mock_progress_bar
stqdmed_iterator._st_text = mock_text
for _ in stqdmed_iterator:
pass
assert stqdmed_iterator._st_progress_bar is mock_progress_bar
mock_progress_bar.empty.assert_not_called()
assert stqdmed_iterator._st_text is mock_text
mock_text.empty.assert_not_called()


@patch("streamlit.empty")
def test_leave_true_remove_stqdm(_):
# pylint: disable=protected-access
stqdmed_iterator = stqdm(range(2), leave=False)
mock_progress_bar = MagicMock()
mock_text = MagicMock()
stqdmed_iterator._st_progress_bar = mock_progress_bar
stqdmed_iterator._st_text = mock_text
for _ in stqdmed_iterator:
pass
assert stqdmed_iterator._st_progress_bar is None
mock_progress_bar.empty.assert_called_once()
assert stqdmed_iterator._st_text is None
mock_text.empty.assert_called_once()


@pytest.mark.parametrize("frontend", [True, False])
@pytest.mark.parametrize("backend", [True, False])
@patch("streamlit.empty")
@patch.object(stqdm, "st_display")
@patch.object(tqdm, "display")
def test_use_stqdm_frontent_backend(tqdm_display_mock, st_display_mock, _, backend, frontend):
# pylint: disable=protected-access
stqdmed_iterator = stqdm(range(2), backend=backend, frontend=frontend, **TQDM_RUN_EVERY_ITERATION)
for _ in stqdmed_iterator:
pass
if backend:
tqdm_display_mock.assert_called()
else:
tqdm_display_mock.assert_not_called()
if frontend:
st_display_mock.assert_called()
else:
st_display_mock.assert_not_called()