Implemented bulk replacing in files.
KOLANICH committed Sep 21, 2021
1 parent b25084a commit 315a627
Showing 10 changed files with 499 additions and 11 deletions.
193 changes: 193 additions & 0 deletions https_everywhere/__main__.py
@@ -0,0 +1,193 @@
import asyncio
import sys
import typing
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from os import cpu_count
from pathlib import Path

try:
    from binaryornot.check import is_binary
except ImportError:  # optional dependency, needed only when binary files are skipped
    is_binary = None
from plumbum import cli

from .core import CombinedReplacerFactory, ReplaceContext
from .core.InBufferReplacer import InBufferReplacer
from .core.InFileReplacer import InFileReplacer
from .replacers.HEReplacer import HEReplacer
from .replacers.HSTSPreloadReplacer import HSTSPreloadReplacer


class OurInBufferReplacer(InBufferReplacer):
__slots__ = ()
FACS = CombinedReplacerFactory(
{
"preloads": HSTSPreloadReplacer,
"heRulesets": HEReplacer,
}
)

def __init__(self, preloads=None, heRulesets=None):
super().__init__(preloads=preloads, heRulesets=heRulesets)


class OurInFileReplacer(InFileReplacer):
def __init__(self, preloads=None, heRulesets=None):
super().__init__(OurInBufferReplacer(preloads=preloads, heRulesets=heRulesets))


class CLI(cli.Application):
"""HTTPSEverywhere-like URI rewriter"""


class FileClassifier:
__slots__ = ("noSkipDot", "noSkipBinary")

def __init__(self, noSkipDot: bool, noSkipBinary: bool):
self.noSkipDot = noSkipDot
self.noSkipBinary = noSkipBinary

    def __call__(self, p: Path) -> str:
        for pa in p.parts:
            if not self.noSkipDot and pa[0] == ".":
                return "dotfile"

        if not p.is_dir():
            if p.is_file():
                if self.noSkipBinary or not is_binary(p):
                    return ""
                else:
                    return "binary"
            else:
                return "not regular file"

        return ""  # directories pass through; the enumerator recurses into them


class FilesEnumerator:
__slots__ = ("classifier", "disallowedReportingCallback")

def __init__(self, classifier, disallowedReportingCallback):
self.classifier = classifier
self.disallowedReportingCallback = disallowedReportingCallback

def __call__(self, fileOrDir: Path):
reasonOfDisallowal = self.classifier(fileOrDir)
if not reasonOfDisallowal:
if fileOrDir.is_dir():
for f in fileOrDir.iterdir():
yield from self(f)
else:
yield fileOrDir
else:
self.disallowedReportingCallback(fileOrDir, reasonOfDisallowal)


@CLI.subcommand("bulk")
class FileRewriteCLI(cli.Application):
"""Rewrites URIs in files. Use - to consume list of files from stdin. Don't use `find`, it is a piece of shit which is impossible to configure to skip .git dirs."""

__slots__ = ("_repl",)

@property
def repl(self) -> InFileReplacer:
if self._repl is None:
self._repl = OurInFileReplacer()
print(
len(self._repl.inBufferReplacer.singleURIReplacer.children[0].preloads),
"HSTS preloads",
)
print(len(self._repl.inBufferReplacer.singleURIReplacer.children[1].rulesets), "HE rules")
return self._repl

    def processEachFileName(self, ctx: ReplaceContext, l: bytes) -> None:
        l = l.strip()
        if l:
            l = l.decode("utf-8")
p = Path(l).resolve().absolute()
self.processEachFilePath(ctx, p)

def processEachFilePath(self, ctx: ReplaceContext, p: Path) -> None:
for pp in self.fe(p):
if self.trace:
print("Processing", pp, file=sys.stderr)
self.repl(ctx, pp)
if self.trace:
print("Processed", pp, file=sys.stderr)

    async def asyncMainPathsFromStdIn(self):
        ctx = ReplaceContext(None)
        asyncStdin = asyncio.StreamReader(loop=self.loop)
        await self.loop.connect_read_pipe(
            lambda: asyncio.StreamReaderProtocol(asyncStdin, loop=self.loop), sys.stdin
        )
        with ThreadPoolExecutor(max_workers=cpu_count()) as pool:
            while not asyncStdin.at_eof():
                l = await asyncStdin.readline()
                # processEachFileName expects (ctx, line)
                await self.loop.run_in_executor(pool, partial(self.processEachFileName, ctx, l))

    async def asyncMainPathsFromCLI(self, filesOrDirs: typing.Iterable[typing.Union[Path, str]]):
        try:
            from tqdm import tqdm
        except ImportError:
            from contextlib import contextmanager

            @contextmanager
            def tqdm(x):  # minimal stand-in so `with tqdm(...) as pb` still works
                yield x

        ctx = ReplaceContext(None)
        replaceInEachFileWithContext = partial(self.repl, ctx)

        with tqdm(filesOrDirs) as pb:
            for fileOrDir in pb:
                fileOrDir = Path(fileOrDir).resolve().absolute()

                files = tuple(self.fe(fileOrDir))

                if files:
                    with ThreadPoolExecutor(max_workers=cpu_count()) as pool:
                        for f in files:
                            if self.trace:
                                print("Processing", f, file=sys.stderr)
                            await self.loop.run_in_executor(pool, partial(replaceInEachFileWithContext, f))
                            if self.trace:
                                print("Processed", f, file=sys.stderr)

    noSkipBinary = cli.Flag(
        ["--no-skip-binary", "-n"],
        help="Don't skip binary files. Also allows running without `binaryornot` installed.",
        default=False,
    )
    noSkipDot = cli.Flag(
        ["--no-skip-dotfiles", "-d"],
        help="Don't skip files and dirs whose names begin with a dot.",
        default=False,
    )
    trace = cli.Flag(
        ["--trace", "-t"],
        help="Print info about processing of regular files",
        default=False,
    )
    noReportSkipped = cli.Flag(
        ["--no-report-skipped", "-s"],
        help="Don't report skipped files",
        default=False,
    )

def disallowedReportingCallback(self, fileOrDir: Path, reasonOfDisallowal: str) -> None:
if not self.noReportSkipped:
print("Skipping ", fileOrDir, ":", reasonOfDisallowal)

def main(self, *filesOrDirs):
        self._repl = None  # type: typing.Optional[OurInFileReplacer]
self.loop = asyncio.get_event_loop()

self.fc = FileClassifier(self.noSkipDot, self.noSkipBinary)
self.fe = FilesEnumerator(self.fc, self.disallowedReportingCallback)

        if len(filesOrDirs) == 1 and filesOrDirs[0] == "-":
t = self.loop.create_task(self.asyncMainPathsFromStdIn())
else:
t = self.loop.create_task(self.asyncMainPathsFromCLI(filesOrDirs))
self.loop.run_until_complete(t)


if __name__ == "__main__":
CLI.run()
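
For orientation, a minimal usage sketch of the new classes (not part of the commit). It assumes only what the listing above shows: OurInFileReplacer() with no arguments loads the bundled HSTS preloads and HE rulesets (as in the repl property), ReplaceContext exposes res and count, and an InFileReplacer is called as replacer(ctx, path).

# Hedged usage sketch reconstructed from __main__.py above; not part of the commit.
from pathlib import Path

from https_everywhere.__main__ import OurInBufferReplacer, OurInFileReplacer
from https_everywhere.core import ReplaceContext

bufReplacer = OurInBufferReplacer()
ctx = bufReplacer("fetched via http://example.com/page")
print(ctx.count, ctx.res)  # number of rewritten URIs and the rewritten buffer

fileReplacer = OurInFileReplacer()
fileReplacer(ReplaceContext(None), Path("docs/readme.txt"))  # presumably rewrites the file in place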
22 changes: 16 additions & 6 deletions https_everywhere/_rules.py
@@ -821,14 +821,24 @@ def _get_ruleset(hostname, rulesets=None):

     logger.debug("no ruleset matches {}".format(hostname))

+
+def _remove_trailing_slash(url):
+    if url[-1] == "/":
+        url = url[:-1]
+    return url
+
+
 def https_url_rewrite(url, rulesets=None):
+    orig_url = url
+    remove_trailing_slash_if_needed = lambda x: x
     if isinstance(url, str):
         # In HTTPSEverywhere, URLs must contain a '/'.
         if url.replace("http://", "").find("/") == -1:
             url += "/"
+            # Only strip a slash that this function itself appended.
+            remove_trailing_slash_if_needed = _remove_trailing_slash
         parsed_url = urlparse(url)
     else:
         parsed_url = url
     if hasattr(parsed_url, "geturl"):
         url = parsed_url.geturl()
@@ -841,19 +851,19 @@ def https_url_rewrite(url, rulesets=None):
     ruleset = _get_ruleset(parsed_url.netloc, rulesets)

     if not ruleset:
-        return url
+        return orig_url

     if not isinstance(ruleset, _Ruleset):
         ruleset = _Ruleset(ruleset[0], ruleset[1])

     if ruleset.exclude_url(url):
-        return url
+        return orig_url

     # process rules
     for rule in ruleset.rules:
         logger.debug("checking rule {} -> {}".format(rule[0], rule[1]))
         try:
-            new_url = rule[0].sub(rule[1], url)
+            new_url, count = rule[0].subn(rule[1], url)  # re.subn returns (new_string, count)
         except Exception as e:  # pragma: no cover
             logger.warning(
                 "failed during rule {} -> {} , input {}: {}".format(
@@ -863,7 +873,7 @@ def https_url_rewrite(url, rulesets=None):
                 raise

         # stop if this rule was a hit
-        if new_url != url:
-            return new_url
+        if count:
+            return remove_trailing_slash_if_needed(new_url)

-    return url
+    return orig_url
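
Taken together, these hunks change the contract of https_url_rewrite: a bare host gets a "/" appended so rules can match, subn()'s substitution count detects a hit even when a rule rewrites a URL to itself, the helper strips only the slash that was appended, and unmodified inputs come back as the original object. A small illustrative sketch (the host names are hypothetical stand-ins; actual rewrites depend on the loaded rulesets):

# Illustration only; whether a URL is rewritten depends on the rulesets.
from https_everywhere._rules import https_url_rewrite

# A bare host has no "/", so one is appended before matching and then
# stripped from a rewritten result, preserving the input's shape:
https_url_rewrite("http://example.com")  # e.g. "https://example.com", no trailing "/"

# When nothing matches, the caller gets back the exact original object,
# which matters when a pre-parsed URL object was passed instead of a str:
https_url_rewrite("http://no-ruleset.invalid/x")  # the input, unchanged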
7 changes: 2 additions & 5 deletions https_everywhere/adapter.py
@@ -3,7 +3,6 @@
 from logging_helper import setup_logging

 import urllib3
-from urllib3.util.url import parse_url

 import requests
 from requests.adapters import HTTPAdapter
@@ -13,6 +12,7 @@
 from ._chrome_preload_hsts import _preload_including_subdomains
 from ._mozilla_preload_hsts import _preload_remove_negative
 from ._util import _check_in
+from .replacers.HSTSPreloadReplacer import apply_HSTS_preload

 PY2 = str != "".__class__
 if PY2:
@@ -155,10 +155,7 @@ def __init__(self, *args, **kwargs):

     def get_redirect(self, url):
         if url.startswith("http://"):
-            p = parse_url(url)
-            if _check_in(self._domains, p.host):
-                new_url = "https:" + url[5:]
-                return new_url
+            return apply_HSTS_preload(url, self._domains)

         return super(PreloadHSTSAdapter, self).get_redirect(url)
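
replacers/HSTSPreloadReplacer.py is among the files not rendered on this page, so the contract of apply_HSTS_preload is not visible here. Judging from the inline code it replaces, it presumably looks roughly like this sketch (the no-match return value is an assumption; note that get_redirect above no longer falls through to super() for http:// URLs, so that value determines the adapter's behavior for non-preloaded hosts):

# Hypothetical reconstruction of apply_HSTS_preload from the code it replaces.
from urllib3.util.url import parse_url

from https_everywhere._util import _check_in

def apply_HSTS_preload(url, domains):
    p = parse_url(url)
    if _check_in(domains, p.host):
        return "https:" + url[5:]  # swap the scheme, keep the remainder verbatim
    return url  # assumed: unmatched URLs come back unchanged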
43 changes: 43 additions & 0 deletions https_everywhere/core/InBufferReplacer.py
@@ -0,0 +1,43 @@
import re
import typing

from . import ReplaceContext, SingleURIReplacer

# Matches http/ftp URIs with optional userinfo; compiled once for str buffers
# and once (ASCII-encoded) for bytes buffers.
uri_re_source = "(?:http|ftp):\\/\\/?((?:[\\w-]+)(?::[\\w-]+)?@)?[\\w\\.:(-]+(?:\\/[\\w\\.:(/-]*)?"
uri_re_text = re.compile(uri_re_source)
uri_re_binary = re.compile(uri_re_source.encode("ascii"))


class InBufferReplacer(SingleURIReplacer):
__slots__ = ("singleURIReplacer",)
FACS = None

def __init__(self, **kwargs):
self.singleURIReplacer = self.__class__.FACS(**kwargs)

def _rePlaceFuncCore(self, uri):
ctx = ReplaceContext(uri)
self.singleURIReplacer(ctx)
return ctx

def _rePlaceFuncText(self, m):
uri = m.group(0)
ctx = self._rePlaceFuncCore(uri)
if ctx.count > 0:
return ctx.res
return uri

def _rePlaceFuncBinary(self, m):
uri = m.group(0)
ctx = self._rePlaceFuncCore(uri.decode("utf-8"))
if ctx.count > 0:
return ctx.res.encode("utf-8")
return uri

def __call__(self, inputStr: typing.Union[str, bytes]) -> ReplaceContext:
if isinstance(inputStr, str):
return ReplaceContext(*uri_re_text.subn(self._rePlaceFuncText, inputStr))
else:
return ReplaceContext(*uri_re_binary.subn(self._rePlaceFuncBinary, inputStr))
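
A short sketch of the buffer-level contract (InBufferReplacer itself has FACS = None, so the OurInBufferReplacer subclass from __main__.py is used; str input goes through uri_re_text, bytes through uri_re_binary, and both return a ReplaceContext built from subn()'s (result, count) pair):

# Usage sketch; assumes the default rulesets load as in __main__.py above.
from https_everywhere.__main__ import OurInBufferReplacer

replacer = OurInBufferReplacer()

textCtx = replacer("see http://example.com/a and ftp://example.org/b")
print(textCtx.count, textCtx.res)

binCtx = replacer(b"see http://example.com/a")
print(binCtx.count, binCtx.res)  # res stays bytes for bytes input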