diff --git a/python-threatexchange/threatexchange/cli/hash_cmd.py b/python-threatexchange/threatexchange/cli/hash_cmd.py index 5d769ae79..73ab279cd 100644 --- a/python-threatexchange/threatexchange/cli/hash_cmd.py +++ b/python-threatexchange/threatexchange/cli/hash_cmd.py @@ -9,12 +9,14 @@ import pathlib import typing as t import tempfile +from pathlib import Path from threatexchange import common from threatexchange.cli.cli_config import CLISettings from threatexchange.cli.exceptions import CommandError from threatexchange.content_type.content_base import ContentType from threatexchange.content_type.photo import PhotoContent +from threatexchange.content_type.content_base import RotationType from threatexchange.signal_type.signal_base import FileHasher, SignalType from threatexchange.cli import command_base @@ -53,6 +55,7 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No signal_choices = sorted( s.get_name() for s in signal_types if issubclass(s, FileHasher) ) + ap.add_argument( "content_type", **common.argparse_choices_pre_type_kwargs( @@ -80,10 +83,29 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No ) ap.add_argument( - "--rotations", - "--R", + "--photo-preprocess", + choices=["unletterbox", "rotations"], + help=( + "Apply one of the preprocessing steps to the image before hashing. " + "'unletterbox' removes black borders, and 'rotations' generates all 8 " + "simple rotations." + ), + ) + + ap.add_argument( + "--black-threshold", + type=int, + default=15, + help=( + "Set the black threshold for unletterboxing (default: 15)." + "Only applies when 'unletterbox' is selected in --preprocess." + ), + ) + + ap.add_argument( + "--save-preprocess", action="store_true", - help="for photos, generate all 8 simple rotations", + help="save the preprocessed image data as new files", ) def __init__( @@ -91,14 +113,20 @@ def __init__( content_type: t.Type[ContentType], signal_type: t.Optional[t.Type[SignalType]], files: t.List[pathlib.Path], - rotations: bool = False, + photo_preprocess: t.Optional[str] = None, + black_threshold: int = 0, + save_preprocess: bool = False, ) -> None: self.content_type = content_type self.signal_type = signal_type - + self.photo_preprocess = photo_preprocess + self.black_threshold = black_threshold + self.save_preprocess = save_preprocess self.files = files - - self.rotations = rotations + if self.photo_preprocess and not issubclass(self.content_type, PhotoContent): + raise CommandError( + "--photo-preprocess flag is only available for Photo content type", 2 + ) def execute(self, settings: CLISettings) -> None: hashers = [ @@ -115,7 +143,7 @@ def execute(self, settings: CLISettings) -> None: hashers = [self.signal_type] # type: ignore # can't detect intersection types - if not self.rotations: + if not self.photo_preprocess: for file in self.files: for hasher in hashers: hash_str = hasher.hash_from_file(file) @@ -123,20 +151,50 @@ def execute(self, settings: CLISettings) -> None: print(hasher.get_name(), hash_str) return - if not issubclass(self.content_type, PhotoContent): - raise CommandError( - "--rotations flag is only available for Photo content type", 2 - ) - - for file in self.files: - with open(file, "rb") as f: - image_bytes = f.read() - rotated_images = PhotoContent.all_simple_rotations(image_bytes) - for rotation_type, rotated_bytes in rotated_images.items(): - with tempfile.NamedTemporaryFile() as temp_file: # Create a temporary file to hold the byte data - temp_file.write(rotated_bytes) - temp_file_path = pathlib.Path(temp_file.name) - for hasher in hashers: - hash_str = hasher.hash_from_file(temp_file_path) - if hash_str: - print(rotation_type.name, hasher.get_name(), hash_str) + def pre_processed_files() -> ( + t.Iterator[t.Tuple[Path, bytes, t.Union[None, RotationType], str]] + ): + """ + Generator that yields preprocessed files and their metadata. + Each item is a tuple of (file path, processed bytes, rotation name, image format). + """ + for file in self.files: + image_format = file.suffix.lower().lstrip(".") + if self.photo_preprocess == "unletterbox": + processed_bytes = PhotoContent.unletterbox( + file, self.black_threshold + ) + yield file, processed_bytes, None, image_format + elif self.photo_preprocess == "rotations": + with open(file, "rb") as f: + image_bytes = f.read() + rotations = PhotoContent.all_simple_rotations(image_bytes) + for rotation_type, processed_bytes in rotations.items(): + yield file, processed_bytes, rotation_type, image_format + + for ( + file, + processed_bytes, + rotation_type, + image_format, + ) in pre_processed_files(): + output_extension = f".{image_format.lower()}" if image_format else ".png" + with tempfile.NamedTemporaryFile( + delete=not self.save_preprocess, suffix=output_extension + ) as temp_file: + temp_file.write(processed_bytes) + temp_file_path = Path(temp_file.name) + for hasher in hashers: + hash_str = hasher.hash_from_file(temp_file_path) + if hash_str: + prefix = rotation_type.name if rotation_type else "" + print(f"{prefix} {hasher.get_name()} {hash_str}") + if self.save_preprocess: + suffix = ( + f"_{rotation_type.name}" if rotation_type else "_unletterboxed" + ) + output_path = file.with_stem(f"{file.stem}{suffix}").with_suffix( + output_extension + ) + temp_file_path.rename(output_path) + print(f"Processed image saved to: {output_path}") diff --git a/python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py b/python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py index 0eaf96a1e..ea9a32d42 100644 --- a/python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py +++ b/python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py @@ -83,17 +83,20 @@ def test_rotations_with_non_photo_content( """Test that rotation flag raises error with non-photo content""" for content_type in ["url", "text", "video"]: hash_cli.assert_cli_usage_error( - ("--rotations", content_type, str(tmp_file)), - msg_regex="--rotations flag is only available for Photo content type", + ("--photo-preprocess=rotations", content_type, str(tmp_file)), + msg_regex="--photo-preprocess flag is only available for Photo content type", ) def test_rotations_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper): """Test that photo rotations are properly processed""" - test_file = pathlib.Path("threatexchange/tests/hashing/resources/LA.png") + resources_dir = ( + pathlib.Path(__file__).parent.parent.parent / "tests/hashing/resources" + ) + test_file = resources_dir / "LA.png" hash_cli.assert_cli_output( - ("--rotations", "photo", str(test_file)), + ("--photo-preprocess=rotations", "photo", str(test_file)), [ "ORIGINAL pdq accb6d39648035f8125c8ce6ba65007de7b54c67a2d93ef7b8f33b0611306715", "ROTATE90 pdq 1f70cbbc77edc5f9524faa1b18f3b76cd0a04a833e20f645d229d0acc8499c56", @@ -105,3 +108,50 @@ def test_rotations_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper): "FLIPMINUS1 pdq 5bb15db9e8a1f03c174a380a55aeaa2985bde9c60abce301bde48df918b5c15b", ], ) + + +def test_unletterbox_with_non_photo_content( + hash_cli: ThreatExchangeCLIE2eHelper, tmp_file: pathlib.Path +): + """Test that unletterbox flag raises error with non-photo content""" + for content_type in ["url", "text", "video"]: + hash_cli.assert_cli_usage_error( + ("--photo-preprocess=unletterbox", content_type, str(tmp_file)), + msg_regex="--photo-preprocess flag is only available for Photo content type", + ) + + +def test_unletterbox_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper): + """Test that photo unletterboxing is properly processed""" + resources_dir = ( + pathlib.Path(__file__).parent.parent.parent / "tests/hashing/resources" + ) + test_file = resources_dir / "letterboxed_sample-b.jpg" + clean_file = resources_dir / "sample-b.jpg" + + hash_cli.assert_cli_output( + ("photo", str(clean_file)), + [ + "pdq f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22", + ], + ) + + """Test that photo unletterboxing is changed based on allowed threshold""" + hash_cli.assert_cli_output( + ("--photo-preprocess=unletterbox", "photo", str(test_file)), + [ + "pdq d8f871cce0f4e84d8a370a32028f63f4b36e27d597621e1d33e6b39c4a9c9b22", + ], + ) + + hash_cli.assert_cli_output( + ( + "--photo-preprocess=unletterbox", + "--black-threshold=25", + "photo", + str(test_file), + ), + [ + "pdq f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22", + ], + ) diff --git a/python-threatexchange/threatexchange/content_type/photo.py b/python-threatexchange/threatexchange/content_type/photo.py index aa0f7e922..10f2c1c73 100644 --- a/python-threatexchange/threatexchange/content_type/photo.py +++ b/python-threatexchange/threatexchange/content_type/photo.py @@ -5,10 +5,12 @@ Wrapper around the video content type. """ from PIL import Image +from pathlib import Path import io import typing as t from .content_base import ContentType, RotationType +from threatexchange.content_type.preprocess import unletterboxing class PhotoContent(ContentType): @@ -102,3 +104,24 @@ def all_simple_rotations(cls, image_data: bytes) -> t.Dict[RotationType, bytes]: RotationType.FLIPMINUS1: cls.flip_minus1(image_data), } return rotations + + @classmethod + def unletterbox(cls, file_path: Path, black_threshold: int = 0) -> bytes: + """ + Remove black letterbox borders from the sides and top of the image based on the specified black_threshold. + Returns the cleaned image as raw bytes. + """ + with file_path.open("rb") as file: + with Image.open(file) as image: + img = image.convert("RGB") + top = unletterboxing.detect_top_border(img, black_threshold) + bottom = unletterboxing.detect_bottom_border(img, black_threshold) + left = unletterboxing.detect_left_border(img, black_threshold) + right = unletterboxing.detect_right_border(img, black_threshold) + + width, height = image.size + cropped_img = image.crop((left, top, width - right, height - bottom)) + + with io.BytesIO() as buffer: + cropped_img.save(buffer, format=image.format) + return buffer.getvalue() diff --git a/python-threatexchange/threatexchange/content_type/preprocess/unletterboxing.py b/python-threatexchange/threatexchange/content_type/preprocess/unletterboxing.py new file mode 100644 index 000000000..ea670c17b --- /dev/null +++ b/python-threatexchange/threatexchange/content_type/preprocess/unletterboxing.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +# Copyright (c) Meta Platforms, Inc. and affiliates. +from PIL import Image + + +def is_pixel_black(pixel: tuple, black_threshold: int): + """ + Check if each color channel in the pixel is below the threshold + """ + r, g, b = pixel + return r <= black_threshold and g <= black_threshold and b <= black_threshold + + +def detect_top_border(image: Image.Image, black_threshold: int = 0) -> int: + """ + Detect the top black border by counting rows with only black pixels. + Checks each RGB channel of each pixel in each row. + Returns the first row that is not all black from the top. + """ + width, height = image.size + for y in range(height): + row_pixels = list(image.crop((0, y, width, y + 1)).getdata()) + if all(is_pixel_black(pixel, black_threshold) for pixel in row_pixels): + continue + return y + return height + + +def detect_bottom_border(image: Image.Image, black_threshold: int = 0) -> int: + """ + Detect the bottom black border by counting rows with only black pixels from the bottom up. + Checks each RGB channel of each pixel in each row. + Returns the first row that is not all black from the bottom. + """ + width, height = image.size + for y in range(height - 1, -1, -1): + row_pixels = list(image.crop((0, y, width, y + 1)).getdata()) + if all(is_pixel_black(pixel, black_threshold) for pixel in row_pixels): + continue + return height - y - 1 + return height + + +def detect_left_border(image: Image.Image, black_threshold: int = 0) -> int: + """ + Detect the left black border by counting columns with only black pixels. + Checks each RGB channel of each pixel in each column. + Returns the first column from the left that is not all black. + """ + width, height = image.size + for x in range(width): + col_pixels = list(image.crop((x, 0, x + 1, height)).getdata()) + if all(is_pixel_black(pixel, black_threshold) for pixel in col_pixels): + continue + return x + return width + + +def detect_right_border(image: Image.Image, black_threshold: int = 0) -> int: + """ + Detect the right black border by counting columns with only black pixels from the right. + Checks each RGB channel of each pixel in each column. + Returns the first column from the right that is not all black. + """ + width, height = image.size + for x in range(width - 1, -1, -1): + col_pixels = list(image.crop((x, 0, x + 1, height)).getdata()) + if all(is_pixel_black(pixel, black_threshold) for pixel in col_pixels): + continue + return width - x - 1 + return width diff --git a/python-threatexchange/threatexchange/tests/hashing/resources/letterboxed_sample-b.jpg b/python-threatexchange/threatexchange/tests/hashing/resources/letterboxed_sample-b.jpg new file mode 100644 index 000000000..d2e23eb6c Binary files /dev/null and b/python-threatexchange/threatexchange/tests/hashing/resources/letterboxed_sample-b.jpg differ diff --git a/python-threatexchange/threatexchange/tests/hashing/resources/sample-b.jpg b/python-threatexchange/threatexchange/tests/hashing/resources/sample-b.jpg new file mode 100644 index 000000000..66ad092df Binary files /dev/null and b/python-threatexchange/threatexchange/tests/hashing/resources/sample-b.jpg differ