diff --git a/pillow_jxl/JpegXLImagePlugin.py b/pillow_jxl/JpegXLImagePlugin.py index 3579c13..c5adb55 100644 --- a/pillow_jxl/JpegXLImagePlugin.py +++ b/pillow_jxl/JpegXLImagePlugin.py @@ -107,6 +107,7 @@ def _save(im, fp, filename, save_all=False): use_original_profile = info.get("use_original_profile", False) jpeg_encode = info.get("lossless_jpeg", None) num_threads = info.get("num_threads", -1) + compress_metadata = info.get("compress_metadata", False) enc = Encoder( mode=im.mode, @@ -140,6 +141,7 @@ def _save(im, fp, filename, save_all=False): "exif": exif or None, "jumb": info.get("jumb") or None, "xmp": info.get("xmp") or None, + "compress": compress_metadata, } data = enc(im.tobytes(), im.width, im.height, jpeg_encode=False, **metadata) fp.write(data) diff --git a/pyproject.toml b/pyproject.toml index e3d93b0..e1f0681 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["numpy", "pytest"] +dev = ["numpy", "pytest", "pyexiv2"] [project.urls] "Homepage" = "https://github.com/Isotr0py/pillow-jpegxl-plugin" diff --git a/src/encode.rs b/src/encode.rs index 70fa771..defa0e8 100644 --- a/src/encode.rs +++ b/src/encode.rs @@ -73,7 +73,7 @@ impl Encoder { }) } - #[pyo3(signature = (data, width, height, jpeg_encode, exif=None, jumb=None, xmp=None))] + #[pyo3(signature = (data, width, height, jpeg_encode, exif=None, jumb=None, xmp=None, compress=false))] fn __call__( &self, py: Python, @@ -84,8 +84,11 @@ impl Encoder { exif: Option<&[u8]>, jumb: Option<&[u8]>, xmp: Option<&[u8]>, + compress: bool, ) -> PyResult> { - py.allow_threads(|| self.call_inner(data, width, height, jpeg_encode, exif, jumb, xmp)) + py.allow_threads(|| { + self.call_inner(data, width, height, jpeg_encode, exif, jumb, xmp, compress) + }) } fn __repr__(&self) -> PyResult { @@ -106,6 +109,7 @@ impl Encoder { exif: Option<&[u8]>, jumb: Option<&[u8]>, xmp: Option<&[u8]>, + compress: bool, ) -> PyResult> { let parallel_runner = ThreadsRunner::new( None, @@ -149,17 +153,17 @@ impl Encoder { let frame = EncoderFrame::new(data).num_channels(self.num_channels); if let Some(exif_data) = exif { encoder - .add_metadata(&Metadata::Exif(exif_data), true) + .add_metadata(&Metadata::Exif(exif_data), compress) .map_err(to_pyjxlerror)? } if let Some(xmp_data) = xmp { encoder - .add_metadata(&Metadata::Xmp(xmp_data), true) + .add_metadata(&Metadata::Xmp(xmp_data), compress) .map_err(to_pyjxlerror)? } if let Some(jumb_data) = jumb { encoder - .add_metadata(&Metadata::Jumb(jumb_data), true) + .add_metadata(&Metadata::Jumb(jumb_data), compress) .map_err(to_pyjxlerror)? } encoder diff --git a/test/test_plugin.py b/test/test_plugin.py index 53741c8..9ddba57 100644 --- a/test/test_plugin.py +++ b/test/test_plugin.py @@ -1,5 +1,6 @@ import tempfile +import pyexiv2 import pytest import numpy as np from PIL import Image @@ -86,31 +87,40 @@ def test_metadata_decode(): def test_metadata_encode_from_jpg(): # Load a JPEG image - img_ori = Image.open("test/images/metadata/1x1_exif_xmp.jpg") + ref_img_path = "test/images/metadata/1x1_exif_xmp.jpg" temp = tempfile.mktemp(suffix=".jxl") + img_ori = Image.open(ref_img_path) img_ori.save(temp, use_container=True) img_enc = Image.open(temp) + img_enc_exiv2 = pyexiv2.Image(temp) + img_ori_exiv2 = pyexiv2.Image(ref_img_path) assert img_ori.getexif() == img_enc.getexif() + assert img_ori_exiv2.read_exif() == img_enc_exiv2.read_exif() -@pytest.mark.skip(reason="Broken test") def test_metadata_encode_from_raw_exif(): with open("test/images/metadata/sample.exif", "rb") as f: ref_exif = f.read() img_ori = Image.open("test/images/sample.png") temp = tempfile.mktemp(suffix=".jxl") - img_ori.save(temp, exif=ref_exif, use_container=True) + img_ori.save(temp, exif=ref_exif) - img_enc = Image.open(temp) - assert ref_exif == img_enc.getexif().tobytes() + ref_exif = pyexiv2.ImageData(ref_exif).read_exif() + jxl_exif = pyexiv2.Image(temp).read_exif() + assert ref_exif == jxl_exif -@pytest.mark.skip(reason="Broken test") def test_metadata_encode_from_pil_exif(): - img_ori = Image.open("test/images/metadata/1x1_exif_xmp.png") + exif_img_path = "test/images/metadata/1x1_exif_xmp.jpg" + dummy_img = Image.open("test/images/sample.png") + exif_img = Image.open(exif_img_path) temp = tempfile.mktemp(suffix=".jxl") - img_ori.save(temp, exif=img_ori.getexif().tobytes(), use_container=True) - - img_enc = Image.open(temp) - assert img_ori.getexif().tobytes() == img_enc.getexif().tobytes() + dummy_img.save(temp, exif=exif_img.getexif().tobytes()) + + ref_exif = pyexiv2.Image(exif_img_path).read_exif() + jxl_exif = pyexiv2.Image(temp).read_exif() + for key in ref_exif: + # Skip UserComment and GPSAltitude as they are broken + if key not in ("Exif.Photo.UserComment", 'Exif.GPSInfo.GPSAltitude'): + assert ref_exif[key] == jxl_exif[key] diff --git a/test_exif.py b/test_exif.py new file mode 100644 index 0000000..c977f53 --- /dev/null +++ b/test_exif.py @@ -0,0 +1,18 @@ +import tempfile +import pyexiv2 +from PIL import Image + +import pillow_jxl + +with open("test/images/metadata/sample.exif", "rb") as f: + ref_exif = f.read() +img_ori = Image.open("test/images/sample.png") +temp = tempfile.mktemp(suffix=".jxl") + +img_ori.save(temp, exif=ref_exif, use_container=True) +img_enc = pyexiv2.Image(temp) +print(img_enc.read_exif()) +# img_enc = pyexiv2.Image("test/images/metadata/1x1_exif_xmp.jxl") +# print(img_enc.read_exif()) +# img_enc = pyexiv2.Image("test/images/metadata/1x1_exif_xmp.jpg") +# print(img_enc.read_exif())