This is the 3d-part realisation of Color2Embed: Fast Exemplar-Based Image Colorization using Color Embeddings (https://arxiv.org/abs/2106.08017) on PyTorch
The article has several ambiguous points that I decided as best I could, in addition, I moved away from the things clearly described in the article: I changed the color encoder to resnet18 instead of simple network in the article and increased the learning rate by 10 times (from 10-4 to 10-3), btw results look like very similar with authors results.
In this paper, we present a fast exemplar-based image colorization approach using color embeddings named Color2Embed. Generally, due to the difficulty of obtaining input and ground truth image pairs, it is hard to train a exemplar-based colorization model with unsupervised and unpaired training manner. Current algorithms usually strive to achieve two procedures: i) retrieving a large number of reference images with high similarity for preparing training dataset, which is inevitably time-consuming and tedious; ii) designing complicated modules to transfer the colors of the reference image to the target image, by calculating and leveraging the deep semantic correspondence between them (e.g., non-local operation), which is computationally expensive during testing. Contrary to the previous methods, we adopt a self-augmented self-reference learning scheme, where the reference image is generated by graphical transformations from the original colorful one whereby the training can be formulated in a paired manner. Second, in order to reduce the process time, our method explicitly extracts the color embeddings and exploits a progressive style feature Transformation network, which injects the color embeddings into the reconstruction of the final image. Such design is much more lightweight and intelligible, achieving appealing performance with fast processing speed.
- Setup enviroment, my (a little redundant) conda enviroment in
environment.yml
- Download ImageNet part from kaggle
- Unpack it and made train list: each line of simple text file should be path to file, e.g.
n09428293/n09428293_23938.JPEG
or full path, set up path to this text file inconfig.IMAGENET_LIST
. If you have full path in list, setconfig.IMAGENET_PREFIX = ''
or to prefix if you extract it, btwos.path.join(config.IMAGENET_PREFIX, <image_list_item>)
should be valid path to image - Change config batch size for your GPU (I have V100 and batch 32, it is around 20 GB GPU memory)
- Run training with DDP from
color2embed
folderpython -m torch.distributed.launch --nproc_per_node=8 train.py
, where 8 - number of GPUs
Code samples some train parts are in color2embed/test_parts.ipynb
Target:
Good samples:
Bad samples:
Download weights from https://drive.google.com/file/d/1xmn-8FvKqm6MoSVYYQq9Rn-BdnOTdrwx/view?usp=sharing and put it in trained_model
folder
See color2embed/inference.ipynb
for details
- Color2Embed: Fast Exemplar-Based Image Colorization using Color Embeddings: https://arxiv.org/abs/2106.08017
- TSP augmentation: https://github.com/cheind/py-thin-plate-spline
- UNet parts: https://github.com/milesial/Pytorch-UNet
- Modulated Convolution: https://github.com/rosinality/stylegan2-pytorch