diff --git a/ImageGoNord/GoNord.py b/ImageGoNord/GoNord.py index 70e7616..3c47d90 100755 --- a/ImageGoNord/GoNord.py +++ b/ImageGoNord/GoNord.py @@ -12,11 +12,14 @@ import ffmpeg import uuid import shutil +import requests -import torch -import skimage.io as io -import skimage.color as convertor -import torchvision.transforms as transforms +try: + import torch + import skimage.color as convertor + import torchvision.transforms as transforms +except ImportError: + print("Please install the dependencies required for the AI feature") try: @@ -31,7 +34,11 @@ from ImageGoNord.utility.quantize import quantize_to_palette import ImageGoNord.utility.palette_loader as pl from ImageGoNord.utility.ConvertUtility import ConvertUtility -from ImageGoNord.utility.model import FeatureEncoder,RecoloringDecoder + +try: + from ImageGoNord.utility.model import FeatureEncoder,RecoloringDecoder +except ImportError: + print("Please install the dependencies required for the AI feature") class NordPaletteFile: @@ -158,6 +165,8 @@ class GoNord(object): TRANSPARENCY_TOLERANCE = 190 MAX_THREADS = 10 + PALETTE_NET_REPO_FOLDER = 'https://github.com/Schrodinger-Hat/ImageGoNord-pip/raw/master/ImageGoNord/models/PaletteNet/' + AVAILABLE_PALETTE = [] PALETTE_DATA = {} @@ -425,6 +434,16 @@ def converted_loop(self, is_rgba, pixels, original_pixels, maxRow, maxCol, minRo pixels[row, col] = tuple(colors_list) return pixels + def load_and_save_models(self): + rd_model = requests.get(self.PALETTE_NET_REPO_FOLDER + 'RD.state_dict.pt') + fe_model = requests.get(self.PALETTE_NET_REPO_FOLDER + 'FE.state_dict.pt') + + with open(os.path.dirname(palette_net.__file__) + '/FE.state_dict.pt', "wb") as f: + f.write(fe_model.content) + + with open(os.path.dirname(palette_net.__file__) + '/RD.state_dict.pt', "wb") as f: + f.write(rd_model.content) + def convert_image_by_model(self, image, use_model_cpu=False): """ Process a Pillow image by using a PyTorch model "PaletteNet" for recoloring the image @@ -444,8 +463,14 @@ def convert_image_by_model(self, image, use_model_cpu=False): FE = FeatureEncoder() # torch.Size([64, 3, 3, 3]) RD = RecoloringDecoder() # torch.Size([530, 256, 3, 3]) - FE.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "FE.state_dict.pt"))) - RD.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "RD.state_dict.pt"))) + if ( + os.path.exists(os.path.dirname(palette_net.__file__) + '/FE.state_dict.pt') + and os.path.exists(os.path.dirname(palette_net.__file__) + '/RD.state_dict.pt') + ): + FE.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "FE.state_dict.pt"))) + RD.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "RD.state_dict.pt"))) + else: + self.load_and_save_models() if use_model_cpu: FE.to("cpu") @@ -472,7 +497,8 @@ def convert_image_by_model(self, image, use_model_cpu=False): try: pal_np = np.array(palette).reshape(1,6,3)/255 except: - print("You have too many colors in your palette for the model, this feature is limited to 6 colours, now you have: ", len(palette), "! I'll take the first 6!") + # this feature is limited to 6 colours + # we're taking the first six pal_np = np.array(palette[0:6]).reshape(1,6,3)/255 pal = torch.Tensor((convertor.rgb2lab(pal_np) - [50,0,0] ) / [50,128,128]).unsqueeze(0) @@ -517,7 +543,7 @@ def convert_image(self, image, save_path='', use_model=False, use_model_cpu=Fals pixels = self.load_pixel_image(image) is_rgba = (image.mode == 'RGBA') - if use_model: + if use_model and torch != None: image = self.convert_image_by_model(image, use_model_cpu) else: if not parallel_threading: diff --git a/ImageGoNord/__init__.py b/ImageGoNord/__init__.py index 5d8a7f1..9871840 100755 --- a/ImageGoNord/__init__.py +++ b/ImageGoNord/__init__.py @@ -1,4 +1,4 @@ # gonord version -__version__ = "1.0.2" +__version__ = "1.1.0" from ImageGoNord.GoNord import * \ No newline at end of file diff --git a/index.py b/index.py index 22de70e..c26353a 100755 --- a/index.py +++ b/index.py @@ -22,12 +22,10 @@ # go_nord.add_file_to_palette(NordPaletteFile.AURORA) # go_nord.add_file_to_palette(NordPaletteFile.FROST) -# image = go_nord.open_image("images/valley.jpg") -# go_nord.convert_image(image, save_path="images/test-valley-ai.jpg", use_model=True) - -output_path = go_nord.convert_video('videos/SampleVideo_720x480.mp4', 'custom_palette', save_path='videos/SampleVideo_converted.mp4') -print(output_path) +image = go_nord.open_image("images/valley.jpg") +go_nord.convert_image(image, save_path="images/test-valley-ai.jpg", use_model=True) exit() +# output_path = go_nord.convert_video('videos/SampleVideo_720x480.mp4', 'custom_palette', save_path='videos/SampleVideo_converted.mp4') image = go_nord.open_image("images/test.jpg") resized_img = go_nord.resize_image(image) diff --git a/setup.py b/setup.py index a7495fa..7264b8b 100755 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="image-go-nord", - version="1.0.2", + version="1.1.0", description="A tool to convert any RGB image or video to any theme or color palette input by the user", long_description=README, long_description_content_type="text/markdown", @@ -17,7 +17,7 @@ author_email="schrodinger.hat.show@gmail.com", license="AGPL-3.0", classifiers=[ - 'Development Status :: 5 - Production/Stable', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package + 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Topic :: Software Development :: Build Tools', "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", @@ -30,8 +30,11 @@ "Bug Reports": "https://github.com/Schrodinger-Hat/ImageGoNord-pip/issues", }, packages=find_packages(), - package_data={'': ['*.txt', 'palettes/*.txt', 'models/*.pt', '*.pt', '*.state_dict.*']}, + package_data={'': ['*.txt', 'palettes/*.txt']}, include_package_data=True, - install_requires=["Pillow", "ffmpeg-python", "numpy", "torch", "scikit-image", "torchvision"], + install_requires=["Pillow", "ffmpeg-python", "numpy", "requests"], + extras_require = { + 'AI': ["torch", "scikit-image", "torchvision"] + }, python_requires=">=3.5" )