[py-tx] embeded tx hash for unletterboxing #1684

Merged
merged 5 commits into from
Nov 27, 2024
Changes from 1 commit
50 changes: 47 additions & 3 deletions python-threatexchange/threatexchange/cli/hash_cmd.py
@@ -19,6 +19,7 @@
from threatexchange.signal_type.signal_base import FileHasher, SignalType
from threatexchange.cli import command_base
from threatexchange.cli.helpers import FlexFilesInputAction
from threatexchange.signal_type.pdq.signal import PdqSignal


class HashCommand(command_base.Command):
@@ -79,6 +80,26 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No
help="only generate these signal types",
)

ap.add_argument(
"--preprocess",
choices=["unletterbox"],
help="Apply preprocessing steps to the image before hashing.",
)

ap.add_argument(
"--black-threshold",
type=int,
default=40,
help="Set the black threshold for unletterboxing. Default is 40.",
)

ap.add_argument(
"--save-output",
type=bool,
default=False,
Contributor

nit: If you make the action store_true then the default is false IIRC.

help="If true, save the processed image as a new file.",
)
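The `store_true` nit in the thread above can be illustrated with a small standalone parser. This is a sketch, not the PR's final code: `--save-output` mirrors the diff, while `--save-output-typed` is a hypothetical flag that exists only to show how `type=bool` misbehaves.

```python
import argparse

ap = argparse.ArgumentParser()
# With action="store_true" the flag takes no value and defaults to False.
ap.add_argument(
    "--save-output",
    action="store_true",
    help="save the processed image as a new file",
)
# Hypothetical flag, for contrast only: type=bool calls bool() on the raw
# string, and any non-empty string (including "False") is truthy.
ap.add_argument("--save-output-typed", type=bool, default=False)

absent = ap.parse_args([])
present = ap.parse_args(["--save-output"])
surprising = ap.parse_args(["--save-output-typed", "False"])
```

Here `absent.save_output` is `False`, `present.save_output` is `True`, and `surprising.save_output_typed` is also `True`, which is why `store_true` is the safer choice for a boolean switch.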

ap.add_argument(
"--rotations",
Contributor
@Dcallies Dcallies Nov 8, 2024

(no changes needed): Do you think we should combine rotations into your generalization of preprocessing?

Alternatively, what do you think about making rotation and unletterboxing mutually exclusive?

Contributor Author

Yes, I thought combining them made the most sense, but I didn't want to jump ahead and do it without asking. I can add it to the list of actions.

Also, would there ever be a workflow where someone may want to both check for rotation and process the image for letterboxing?
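If the two options were made mutually exclusive, as floated in this thread, argparse could enforce it directly. The following is only a sketch of that alternative: the option names come from the diff, but `build_parser`, the `prog` name, and the help strings are placeholders, and the PR itself does not take this route.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser in which --rotations and --preprocess cannot be combined."""
    ap = argparse.ArgumentParser(prog="hash-sketch")
    # argparse rejects any command line that uses more than one option
    # from the same mutually exclusive group.
    group = ap.add_mutually_exclusive_group()
    group.add_argument(
        "--rotations",
        "--R",
        action="store_true",
        help="placeholder help: also hash rotated variants",
    )
    group.add_argument(
        "--preprocess",
        choices=["unletterbox"],
        help="placeholder help: preprocess the image before hashing",
    )
    return ap


args = build_parser().parse_args(["--preprocess", "unletterbox"])
```

Passing both `--rotations` and `--preprocess unletterbox` makes the parser print a usage error and exit, so the command body never has to handle the combination.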

"--R",
@@ -92,10 +113,15 @@ def __init__(
signal_type: t.Optional[t.Type[SignalType]],
files: t.List[pathlib.Path],
rotations: bool = False,
preprocess: t.Optional[str] = None,
black_threshold: int = 40,
save_output: bool = False,
) -> None:
self.content_type = content_type
self.signal_type = signal_type

self.preprocess = preprocess
self.black_threshold = black_threshold
self.save_output = save_output
self.files = files

self.rotations = rotations
@@ -118,7 +144,17 @@ def execute(self, settings: CLISettings) -> None:
if not self.rotations:
for file in self.files:
for hasher in hashers:
hash_str = hasher.hash_from_file(file)
if isinstance(hasher, PdqSignal) and (
Contributor

Why only PdqSignal? Wouldn't other image perceptual hashing algorithms benefit from this?

We also generally want to avoid places where we do isinstance(<interface_implementation>), in favor of having it handled by the interface itself.

Contributor Author
@Mackay-Fisher Mackay-Fisher Nov 13, 2024

Okay, I thought it was specific to PdqSignal. As for why I bypassed the interface: not all signal types have the hash_from_bytes method, and I did not always want to create a new file with the updated image bytes. Is there a way around this, or would it be better to always create the new file, even if only temporarily, and then save the output if the user passes the flag to save it?

Contributor

It's not specific to PdqSignal, and the hash_from_bytes method is itself part of a wider interface.

I think I eventually decided it was simpler to write everything to tmpfiles, which is how we ended up with the current implementation.

However, similar to the feedback I gave for --rotations, we can make our life a lot easier by having the preprocessing happen in between the file input and the for file in self.files.

I like your idea of providing a way to pass a flag to save it, but let's save that for a followup.

self.content_type.get_name() == "photo"
Contributor

Hmm, more specialization.

and self.preprocess == "unletterbox"
):
hash_str = PdqSignal.hash_from_bytes(
PhotoContent.unletterbox(
file, self.save_output, self.black_threshold
)
)
else:
hash_str = hasher.hash_from_file(file)
if hash_str:
print(hasher.get_name(), hash_str)
return
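The review suggestion above, to run preprocessing between file input and the `for file in self.files` loop and to write the results to temp files, might look roughly like the sketch below. The helper name `preprocessed_files` and its `transform` parameter are hypothetical; the point is only that the hashing loop can keep calling `hash_from_file` on every hasher, with no `isinstance` check.

```python
import contextlib
import pathlib
import tempfile
import typing as t


@contextlib.contextmanager
def preprocessed_files(
    files: t.Sequence[pathlib.Path],
    transform: t.Optional[t.Callable[[bytes], bytes]],
) -> t.Iterator[t.List[pathlib.Path]]:
    """Yield the paths to hash: the originals, or temp copies of the transformed bytes."""
    if transform is None:
        yield list(files)
        return
    with tempfile.TemporaryDirectory() as tmpdir:
        out: t.List[pathlib.Path] = []
        for i, file in enumerate(files):
            # Keep the original name (and suffix) so format detection by
            # extension still works; the index prefix avoids name clashes.
            tmp_path = pathlib.Path(tmpdir) / f"{i}_{file.name}"
            tmp_path.write_bytes(transform(file.read_bytes()))
            out.append(tmp_path)
        yield out
        # Leaving the with-block deletes the temp directory and its files.


# Assumed call site inside execute(), shown as comments because the
# surrounding command class is not reproduced here:
#
#     with preprocessed_files(self.files, transform) as files:
#         for file in files:
#             for hasher in hashers:
#                 hash_str = hasher.hash_from_file(file)
```

Because the transform is a plain bytes-to-bytes callable, the same helper would also cover the rotation case if rotations were later folded into preprocessing.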
@@ -130,7 +166,15 @@ def execute(self, settings: CLISettings) -> None:

for file in self.files:
with open(file, "rb") as f:
image_bytes = f.read()
if (
self.content_type.get_name() == "photo"
and self.preprocess == "unletterbox"
Contributor

nit: Instead, in the __init__ validate that unletterbox can only be used with photo, which will save you many future checks like this.

):
image_bytes = PhotoContent.unletterbox(
file, self.save_output, self.black_threshold
)
else:
image_bytes = f.read()
rotated_images = PhotoContent.all_simple_rotations(image_bytes)
for rotation_type, rotated_bytes in rotated_images.items():
with tempfile.NamedTemporaryFile() as temp_file: # Create a temporary file to hold the byte data
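The nit in this hunk, validating once that unletterbox is only used with photo, could be handled with a small up-front check. This is a minimal sketch: `validate_preprocess` is a hypothetical helper, and `ValueError` stands in for whatever error type the CLI actually uses to report bad arguments.

```python
import typing as t


def validate_preprocess(content_type_name: str, preprocess: t.Optional[str]) -> None:
    """Reject unsupported content-type / preprocessing combinations up front."""
    if preprocess == "unletterbox" and content_type_name != "photo":
        # Assumption: the real command would raise its own CLI error type here.
        raise ValueError(
            "--preprocess unletterbox is only supported for photo content"
        )


# Valid combinations pass silently.
validate_preprocess("photo", "unletterbox")
validate_preprocess("video", None)
```

Called from `__init__` with `content_type.get_name()`, a check like this would let `execute` test only `self.preprocess == "unletterbox"` and drop the repeated `get_name() == "photo"` comparisons.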
119 changes: 119 additions & 0 deletions python-threatexchange/threatexchange/content_type/photo.py
@@ -5,7 +5,9 @@
Wrapper around the video content type.
"""
from PIL import Image
from pathlib import Path
import io
import os

from .content_base import ContentType, RotationType

@@ -101,3 +103,120 @@ def all_simple_rotations(cls, image_data: bytes):
RotationType.FLIPMINUS1: cls.flip_minus1(image_data),
}
return rotations

@classmethod
def detect_top_border(
cls, grayscale_img: Image.Image, black_threshold: int = 10
) -> int:
"""
Detect the top black border by counting rows with only black pixels.
Uses a default black threshold of 10, so only rows whose pixels all have
a brightness below 10 are treated as border.

Returns the index of the first row from the top that is not all black.
"""
width, height = grayscale_img.size
for y in range(height):
row_pixels = list(grayscale_img.crop((0, y, width, y + 1)).getdata())
if all(pixel < black_threshold for pixel in row_pixels):
continue
return y
return height

@classmethod
def detect_bottom_border(
cls, grayscale_img: Image.Image, black_threshold: int = 10
) -> int:
"""
Detect the bottom black border by counting rows with only black pixels from the bottom up.
Uses a default black threshold of 10, so only rows whose pixels all have
a brightness below 10 are treated as border.

Returns the number of all-black rows counted from the bottom.
"""
width, height = grayscale_img.size
for y in range(height - 1, -1, -1):
row_pixels = list(grayscale_img.crop((0, y, width, y + 1)).getdata())
if all(pixel < black_threshold for pixel in row_pixels):
continue
return height - y - 1
return height

@classmethod
def detect_left_border(
cls, grayscale_img: Image.Image, black_threshold: int = 10
) -> int:
"""
Detect the left black border by counting columns with only black pixels.
Uses a default black threshold of 10, so only columns whose pixels all have
a brightness below 10 are treated as border.

Returns the index of the first column from the left that is not all black.
"""
width, height = grayscale_img.size
for x in range(width):
col_pixels = list(grayscale_img.crop((x, 0, x + 1, height)).getdata())
if all(pixel < black_threshold for pixel in col_pixels):
continue
return x
return width

@classmethod
def detect_right_border(
cls, grayscale_img: Image.Image, black_threshold: int = 10
) -> int:
"""
Detect the right black border by counting columns with only black pixels from the right.
Uses a default black threshold of 10, so only columns whose pixels all have
a brightness below 10 are treated as border.

Returns the number of all-black columns counted from the right.
"""
width, height = grayscale_img.size
for x in range(width - 1, -1, -1):
col_pixels = list(grayscale_img.crop((x, 0, x + 1, height)).getdata())
if all(pixel < black_threshold for pixel in col_pixels):
continue
return width - x - 1
return width
Contributor

By putting these all at the top level, we are signaling that they are part of the "official" interface for photos.

Instead, let's move this functionality into its own file in a new /preprocessing directory. We can add unletterbox.py as its own module, with these 4 methods then as standalone.
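As a rough illustration of the standalone-module idea, the four detectors can share one scanning helper once they are no longer classmethods. This sketch makes two loud assumptions: it works on a plain grid of grayscale values rather than a PIL image, so it runs without Pillow, and the names `_leading_black` and `detect_borders` are invented here. It keeps the diff's `pixel < black_threshold` comparison.

```python
import typing as t

Grid = t.Sequence[t.Sequence[int]]  # rows of grayscale values, 0-255


def _leading_black(lines: t.Iterable[t.Sequence[int]], black_threshold: int) -> int:
    """Count consecutive lines from the start whose pixels are all below the threshold."""
    count = 0
    for line in lines:
        if not all(pixel < black_threshold for pixel in line):
            break
        count += 1
    return count


def detect_borders(grid: Grid, black_threshold: int = 10) -> t.Tuple[int, int, int, int]:
    """Return the (left, top, right, bottom) border sizes in pixels."""
    rows = [list(row) for row in grid]
    cols = [list(col) for col in zip(*rows)]  # transpose rows into columns
    top = _leading_black(rows, black_threshold)
    bottom = _leading_black(reversed(rows), black_threshold)
    left = _leading_black(cols, black_threshold)
    right = _leading_black(reversed(cols), black_threshold)
    return left, top, right, bottom


grid = [
    [0, 0, 0, 0, 0],
    [0, 0, 200, 0, 0],
    [0, 0, 0, 50, 0],
    [0, 0, 0, 0, 0],
]
borders = detect_borders(grid)  # (2, 1, 1, 1)
```

As in the diff, a fully black input reports the whole height and width as border on each side, so a caller would need to guard against an empty crop box.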


@classmethod
def unletterbox(
cls, file_path: Path, save_output: bool = False, black_threshold: int = 40
Contributor

blocking: Instead of making save_output an argument here, it might be better to compose it from the outside by taking the bytes, which will give the caller more control over the file directory.

blocking q: Can you tell me more about how you picked 40? We may want to be very conservative by default (even to only 100% black).

) -> bytes:
"""
Remove black letterbox borders from all four sides of the image.

Converts the image to grayscale, then finds the rows and columns at
each edge that are completely blacked out.

Those edges are then cropped away to give back the cleaned image bytes.

Return the new hash of the cleaned image with an option to create a new output file as well
Contributor

I don't think this returns the hash, no?

Contributor Author

Ahh no, I had it returning the hash at first, but that created a circular dependency, so I removed it but did not update the comment. I will update it.

"""
# Open the original image
with Image.open(file_path) as img:
grayscale_img = img.convert("L")
Contributor

blocking q: Hmm, why convert to grayscale first? Won't this convert some full colors to black?

Contributor Author

I revised this and updated it to check each of the pixel's R, G, and B values individually.


top = cls.detect_top_border(grayscale_img, black_threshold)
bottom = cls.detect_bottom_border(grayscale_img, black_threshold)
left = cls.detect_left_border(grayscale_img, black_threshold)
right = cls.detect_right_border(grayscale_img, black_threshold)

width, height = grayscale_img.size
cropped_box = (left, top, width - right, height - bottom)

cropped_img = img.crop(cropped_box)

# Optionally save the unletterboxed image to a new file in the same directory
if save_output:
path = Path(file_path)
output_path = path.parent / f"{path.stem}_unletterboxed{path.suffix}"
cropped_img.save(output_path)
print(f"Unletterboxed image saved to: {output_path}")

# Convert the cropped image to bytes for hashing
with io.BytesIO() as buffer:
cropped_img.save(buffer, format=img.format)
Contributor

Why .img? Should we use the same format that was passed in?

cropped_image_data = buffer.getvalue()
return cropped_image_data
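Several threads on this function (take bytes and let the caller save, avoid the grayscale conversion, be conservative about what counts as black, keep the input format) can be combined in one hedged sketch. It is not the merged implementation: `unletterbox_bytes` is a hypothetical name, the default threshold of 0 (pure black only) is an assumption prompted by the review question, and the comparison is `<=` here rather than the diff's `<` so that 0 is meaningful. It needs Pillow, as the diff does.

```python
import io

from PIL import Image


def unletterbox_bytes(image_bytes: bytes, black_threshold: int = 0) -> bytes:
    """Crop away borders whose pixels have every RGB channel <= black_threshold."""
    with Image.open(io.BytesIO(image_bytes)) as img:
        # Map each channel to 0 (dark enough) or 255. getbbox() then bounds
        # every pixel that has at least one channel above the threshold.
        mask = img.convert("RGB").point(lambda p: 255 if p > black_threshold else 0)
        bbox = mask.getbbox()
        if bbox is None or bbox == (0, 0) + img.size:
            # Entirely black, or nothing to trim: return the input untouched.
            return image_bytes
        with io.BytesIO() as buffer:
            # Re-encode in the format the input arrived in.
            img.crop(bbox).save(buffer, format=img.format)
            return buffer.getvalue()


# Build a 10x8 black canvas with a 5x5 pure-blue block as sample input.
canvas = Image.new("RGB", (10, 8), (0, 0, 0))
canvas.paste((0, 0, 255), (2, 1, 7, 6))
sample = io.BytesIO()
canvas.save(sample, format="PNG")
cropped = unletterbox_bytes(sample.getvalue())

# Saving is composed by the caller, who chooses the destination, e.g.:
# pathlib.Path("letterbox_unletterboxed.png").write_bytes(cropped)
```

Pure blue is used in the sample on purpose: Pillow's `L` conversion maps (0, 0, 255) to roughly 29, which is below the diff's default threshold of 40, so a grayscale check would treat it as border while the per-channel check keeps it.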
@@ -0,0 +1,81 @@
import unittest
from pathlib import Path
from threatexchange.signal_type.pdq.signal import PdqSignal
from threatexchange.content_type.photo import PhotoContent


class TestUnletterboxFunction(unittest.TestCase):
def setUp(self):
# Load the file paths
current_path = Path(__file__).parent
self.letterbox_path = Path(f"{current_path}/resources/letterbox.png")
self.clean_path = Path(f"{current_path}/resources/clean.png")
self.output_path = Path(f"{current_path}/resources/letterbox_unletterboxed.png")

def clean(self):
# Removes generated output file if already exists
if self.output_path.exists():
self.output_path.unlink()

def test_letterbox_image_without_unletterbox(self):
with self.letterbox_path.open("rb") as f:
letterbox_data = f.read()

letterbox_hash = PdqSignal.hash_from_bytes(letterbox_data)

with self.clean_path.open("rb") as f:
clean_data = f.read()
clean_hash = PdqSignal.hash_from_bytes(clean_data)

# Assert that the hash of the original letterbox image is different from the clean image's hash
self.assertNotEqual(
letterbox_hash,
clean_hash,
"Letterbox image unexpectedly matches the clean image",
)

def test_unletterbox_image(self):
# Generate PDQ hash for the unletterboxed image
unletterboxed_hash = PdqSignal.hash_from_bytes(
PhotoContent.unletterbox(self.letterbox_path)
)

# Read the clean image data and generate PDQ hash
with self.clean_path.open("rb") as f:
clean_data = f.read()
clean_hash = PdqSignal.hash_from_bytes(clean_data)

self.assertEqual(
unletterboxed_hash,
clean_hash,
"Unletterboxed image does not match the clean image",
)

def test_unletterboxfile_creates_matching_image(self):
# Generate the hash and also create a new output file
generated_hash = PdqSignal.hash_from_bytes(
PhotoContent.unletterbox(self.letterbox_path, True)
)
self.assertTrue(
self.output_path.exists(), "The unletterboxed output file was not created."
)

# Generate PDQ hash for the clean image
with self.clean_path.open("rb") as f:
clean_data = f.read()
clean_hash = PdqSignal.hash_from_bytes(clean_data)

# Assert that the hash of the generated unletterboxed image matches the clean image's hash
self.assertEqual(
generated_hash,
clean_hash,
"Unletterboxfile output does not match the clean image",
)

# Removes created file
if self.output_path.exists():
self.output_path.unlink()


if __name__ == "__main__":
unittest.main()