diff --git a/.DS_Store b/.DS_Store
deleted file mode 100644
index 856e88d..0000000
Binary files a/.DS_Store and /dev/null differ
diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000..3d436b6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,20 @@
+__pycache__
+data/*
+exp/demo/*
+openai
+*.o
+*.so
+*.pyx
+core/lib/freqencoder/build
+core/lib/gridencoder/build
+core/lib/freqencoder/dist
+core/lib/gridencoder/dist
+*.egg-info
+logs
+*.pt
+*.pth
+thirdparties/MODNet
+thirdparties/clip
+clip_ckpts
+input_data/demo/*
+.DS_Store
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
index 51c9a49..2b454fc 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
MIT License
-Copyright (c) 2023 Yangyi Huang
+Copyright (c) 2023 Yangyi Huang, Hongwei Yi, Yuliang Xiu, Tingting Liao, Jiaxiang Tang, Deng Cai, Justus Thies
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
diff --git a/README.md b/README.md
index e328088..9883125 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,30 @@
TeCH considers image-based reconstruction as a conditional generation task, taking conditions from both the input image and the derived descriptions. It is capable of reconstructing "lifelike" 3D clothed humans. “Lifelike” refers to 1) a detailed full-body geometry, including facial features and clothing wrinkles, in both frontal and unseen regions, and 2) a high-quality texture with consistent color and intricate patterns.
+## Installation
+
+Please follow the [Installation Instruction](docs/install.md) to set up all the required packages.
+
+## Getting Started
+
+We provide a running script at `scripts/run.sh`. Before getting started, you need to set your own environment variables `CUDA_HOME` and `REPLICATE_API_TOKEN` ([get your token here](https://replicate.com/signin?next=/account/api-tokens)) in the script.
+
+After that, you can use TeCH to create a highly detailed clothed human textured mesh from a single image, for example:
+
+```shell
+sh scripts/run.sh input/examples/name.img exp/examples/name
+```
+
+The results will be saved in the experiment folder `exp/examples/name`, and the textured mesh will be saved as `exp/examples/name/obj/name_texture.obj`.
+
+Note that in "Step 3", the current version of the DreamBooth implementation requires 2\*32G GPU memory, while 1\*32G GPU memory is sufficient for the other steps. The entire training process for a subject takes ~3 hours on our V100 GPUs.
+
+## TODOs
+
+- [ ] Release of evaluation protocols and results data for comparison (on CAPE & THUman 2.0 datasets).
+- [ ] Try to use the diffusers version of DreamBooth to save training memory.
+- [ ] Further improvement of efficiency and robustness.
+
## Citation
```bibtex
@@ -50,3 +74,9 @@ TeCH considers image-based reconstruction as a conditional generation task, taki
year={2024}
}
```
+## License
+This code and model are available for non-commercial scientific research purposes as defined in the LICENSE (i.e., MIT LICENSE).
+Note that to use TeCH you have to register for SMPL-X and agree to its license, which is not the MIT License; you can check the SMPL-X license at https://github.com/vchoutas/smplx/blob/main/LICENSE.
+
+## Acknowledgment
+This implementation is mainly built based on [Stable Dreamfusion](https://github.com/ashawkey/stable-dreamfusion), [ECON](https://github.com/YuliangXiu/ECON), [DreamBooth-Stable-Diffusion](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion), and the BLIP API from Salesforce on [Replicate](https://replicate.com/salesforce/blip).
\ No newline at end of file
diff --git a/configs/default.yaml b/configs/default.yaml
new file mode 100755
index 0000000..5192b07
--- /dev/null
+++ b/configs/default.yaml
@@ -0,0 +1,160 @@
+workspace: null
+exp_root: null
+stage: null
+use_gl: False
+profile: False
+fp16: False
+
+model:
+ use_dmtet_network: False
+ use_explicit_tet: False
+ use_color_network: false
+ tet_shell_offset: 0.1
+ tet_shell_decimate: 0.9
+ tet_offset_scale: 0.
+ tet_grid_scale: 0.
+ tet_grid_volume: 0.00000005
+ tet_num_subdiv: 0
+ dmtet_network: hash
+ render_ssaa: 4
+ use_texture_2d: false
+ use_vertex_tex: False
+ mesh_scale: 1.0
+ albedo_res: 2048
+ different_bg: false
+ single_bg_color: False
+ use_can_pose_space: False
+ geo_hash_num_levels: 16
+ geo_hash_max_res: 1024
+ color_hash_num_levels: 16
+ color_hash_max_res: 2048
+ color_num_layers: 1
+ color_hidden_dim: 32
+ min_near: 0.01
+
+
+train:
+ dmtet_lr: 0.1
+ init_texture_3d: False
+ init_mesh: True
+ init_mesh_padding: 0.
+ tet_subdiv_steps: null
+ workspace: null
+ eval_interval: 10
+ lock_geo: False
+ fp16: False
+ render_ssaa: 4
+ w: 512
+ h: 512
+
+ iters: 0
+ lr: 0.001
+ warm_iters: 0
+ min_lr: 0
+
+ ckpt: latest
+ pretrained: null
+
+ optim: adan
+
+ render_relative_normal: true
+ albedo_sample_ratio: 1.0
+ normal_sample_ratio: 0.
+ textureless_sample_ratio: 0.
+ can_pose_sample_ratio: 0.
+ train_both: false
+
+ loss_mask_erosion: 10
+
+ lambda_normal: 0.
+ lambda_lap: 0.
+ lambda_recon: 0.
+ lambda_sil: 0.
+ lambda_color_chamfer: 0.
+
+ crop_for_lpips: false
+ use_lap_loss: false
+ single_directional_color_chamfer: False
+ color_chamfer_step: 0
+ color_chamfer_space: rgb
+
+ decay_lnorm_cosine_cycle: null
+ decay_lnorm_cosine_max_iter: null
+ decay_lnorm_iter: null
+ decay_lnorm_ratio: null
+
+ jitter_pose: False
+ radius_range: [0.7, 1.3]
+ height_range: [-0.4, 0.4]
+ fovy_range: [40, 70]
+ theta_range: [60, 120]
+ phi_range: [0., 360.]
+ phi_diff: 30
+ angle_front: 60
+ angle_overhead: 30
+ face_sample_ratio: 0.3
+ face_height_range: [0., 0.]
+ face_radius_range: [0.3, 0.4]
+ face_phi_diff: 30
+ face_theta_range: [90, 90]
+ face_phi_range: [-90, 90]
+
+ init_empty_tex: False
+
+data:
+ load_input_image: True
+ img: null
+ load_front_normal: false
+ front_normal_img: null
+ load_back_normal: false
+ back_normal_img: null
+ load_keypoints: True
+ keypoints_path: null
+ load_result_mesh: False
+ last_model: null
+ last_ref_model: null
+ smpl_model: null
+ load_apose_mesh: False
+ can_pose_folder: null
+ load_occ_mask: False
+ occ_mask: null
+ loss_mask: null
+ load_da_pose_mesh: False
+ da_pose_mesh: null
+
+guidance:
+ type: stable-diffusion
+ use_view_prompt: True
+ sd_version: 1.5
+ guidance_scale: 100.
+ step_range: [0.02, 0.25]
+ use_dreambooth: True
+ hf_key: null
+ head_hf_key: null
+ lora: null
+ text: null
+ text_geo: null
+ text_head: null
+ text_extra: ''
+ normal_text: null
+ normal_text_extra: ''
+ textureless_text: null
+ textureless_text_extra: ''
+ negative: ''
+ negative_normal: ''
+ negative_textureless: ''
+ controlnet: null
+ controlnet_guidance_geometry: null
+ controlnet_conditioning_scale: 0.
+ controlnet_openpose_guidance: null
+
+test:
+ test: false
+ not_test_video: False
+ save_mesh: True
+ save_uv: False
+ write_image: False
+ W: 800
+ H: 800
+
\ No newline at end of file
diff --git a/configs/tech_geometry.yaml b/configs/tech_geometry.yaml
new file mode 100755
index 0000000..93c9cdc
--- /dev/null
+++ b/configs/tech_geometry.yaml
@@ -0,0 +1,41 @@
+exp_root: null
+stage: geometry
+model:
+ use_dmtet_network: True
+ tet_offset_scale: 0.
+ tet_grid_volume: 5e-8
+ tet_num_subdiv: 1
+ render_ssaa: 4
+train:
+ iters: 10000
+ tet_subdiv_steps: [5000]
+ use_lap_loss: True
+ normal_sample_ratio: 1.0
+ radius_range: [0.7, 1.3]
+ height_range: [-0.4, 0.4]
+ theta_range: [60, 120]
+ phi_diff: 30
+ face_sample_ratio: 0.3
+ face_height_range: [0., 0.]
+ face_radius_range: [0.3, 0.4]
+ face_phi_diff: 30
+ face_theta_range: [90, 90]
+ face_phi_range: [-90, 90]
+ render_relative_normal: True
+ lambda_lap: 1e4
+ lambda_sil: 1e4
+ lambda_normal: 1e4
+ lambda_recon: 0.
+ lambda_color_chamfer: 0.
+ decay_lnorm_cosine_cycle: 5000
+ decay_lnorm_cosine_max_iter: 10000
+
+data:
+ load_input_image: True
+ load_front_normal: True
+ load_back_normal: True
+guidance:
+ normal_text: "a smooth and detailed sculpture of"
+ use_view_prompt: True
+ guidance_scale: 100.
+ step_range: [0.02, 0.25]
diff --git a/configs/tech_texture.yaml b/configs/tech_texture.yaml
new file mode 100755
index 0000000..6d70bd9
--- /dev/null
+++ b/configs/tech_texture.yaml
@@ -0,0 +1,44 @@
+exp_root: ''
+stage: texture
+model:
+ use_dmtet_network: false
+ use_color_network: true
+ tet_offset_scale: 0.
+ tet_grid_volume: 5e-8
+ tet_num_subdiv: 1
+ render_ssaa: 4
+ use_can_pose_space: True
+train:
+ lock_geo: True
+ iters: 7000
+ normal_sample_ratio: 0.
+ radius_range: [0.7, 1.3]
+ height_range: [-0.4, 0.4]
+ theta_range: [60, 120]
+ phi_diff: 30
+ face_sample_ratio: 0.3
+ face_height_range: [0., 0.]
+ face_radius_range: [0.3, 0.4]
+ face_phi_diff: 30
+ face_theta_range: [90, 90]
+ face_phi_range: [-90, 90]
+ lambda_lap: 0.
+ lambda_sil: 0.
+ lambda_normal: 0.
+ lambda_recon: 10000.
+ lambda_color_chamfer: 1e6
+ color_chamfer_step: 5000
+ crop_for_lpips: true
+test:
+ save_mesh: true
+ test: False
+data:
+ load_input_image: True
+ load_front_normal: True
+ load_back_normal: True
+ load_result_mesh: True
+ load_apose_mesh: True
+guidance:
+ use_view_prompt: True
+ guidance_scale: 100.
+ step_range: [0.02, 0.25]
diff --git a/core/lib/annotators.py b/core/lib/annotators.py
new file mode 100755
index 0000000..6fbc52e
--- /dev/null
+++ b/core/lib/annotators.py
@@ -0,0 +1,166 @@
+import numpy as np
+import cv2
+import os
+import torch
+from einops import rearrange
+
+
+class Network(torch.nn.Module):
+    """HED-style edge network: five VGG stages, each with a 1x1 score head.
+
+    The five single-channel score maps are bilinearly resized to the input
+    resolution, concatenated along the channel axis and fused by a 1x1
+    convolution followed by a sigmoid.
+
+    Args:
+        model_path: path of a checkpoint readable by ``torch.load``. Its keys
+            are expected to contain ``module`` where this class uses ``net``
+            (e.g. ``moduleVggOne.0.weight``); they are renamed before loading.
+    """
+
+    @staticmethod
+    def _vgg_stage(in_channels, out_channels, num_convs, pool):
+        """Build one stage: optional 2x2 max-pool, then ``num_convs`` conv3x3+ReLU pairs."""
+        layers = [torch.nn.MaxPool2d(kernel_size=2, stride=2)] if pool else []
+        for conv_idx in range(num_convs):
+            layers.append(torch.nn.Conv2d(
+                in_channels=in_channels if conv_idx == 0 else out_channels,
+                out_channels=out_channels, kernel_size=3, stride=1, padding=1))
+            layers.append(torch.nn.ReLU(inplace=False))
+        return torch.nn.Sequential(*layers)
+
+    @staticmethod
+    def _score_head(in_channels):
+        """Build the 1x1 convolution that turns stage features into one score channel."""
+        return torch.nn.Conv2d(
+            in_channels=in_channels, out_channels=1, kernel_size=1, stride=1, padding=0)
+
+    def __init__(self, model_path):
+        super().__init__()
+
+        # Layer order inside every Sequential matches the hand-written original,
+        # so the state-dict keys (netVggTwo.1.weight, ...) are unchanged.
+        self.netVggOne = self._vgg_stage(3, 64, num_convs=2, pool=False)
+        self.netVggTwo = self._vgg_stage(64, 128, num_convs=2, pool=True)
+        self.netVggThr = self._vgg_stage(128, 256, num_convs=3, pool=True)
+        self.netVggFou = self._vgg_stage(256, 512, num_convs=3, pool=True)
+        self.netVggFiv = self._vgg_stage(512, 512, num_convs=3, pool=True)
+
+        self.netScoreOne = self._score_head(64)
+        self.netScoreTwo = self._score_head(128)
+        self.netScoreThr = self._score_head(256)
+        self.netScoreFou = self._score_head(512)
+        self.netScoreFiv = self._score_head(512)
+
+        self.netCombine = torch.nn.Sequential(
+            torch.nn.Conv2d(in_channels=5, out_channels=1,
+                            kernel_size=1, stride=1, padding=0),
+            torch.nn.Sigmoid()
+        )
+
+        # Deserialize on the CPU: load_state_dict copies into this module's own
+        # parameters anyway, and this avoids depending on the device the
+        # checkpoint happened to be saved from.
+        checkpoint = torch.load(model_path, map_location="cpu")
+        self.load_state_dict({strKey.replace('module', 'net'): tenWeight
+                              for strKey, tenWeight in checkpoint.items()})
+
+    def forward(self, tenInput):
+        """Return the fused edge map, shape ``(B, 1, H, W)`` for a ``(B, 3, H, W)`` input.
+
+        The input is multiplied by 255 and a fixed per-channel offset is
+        subtracted before the first stage.
+        NOTE(review): the offsets look like Caffe-VGG channel means; confirm
+        callers pass the channel order those means assume.
+        """
+        tenInput = tenInput * 255.0
+        tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434],
+                                           dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)
+
+        stages = (self.netVggOne, self.netVggTwo, self.netVggThr, self.netVggFou, self.netVggFiv)
+        heads = (self.netScoreOne, self.netScoreTwo, self.netScoreThr, self.netScoreFou, self.netScoreFiv)
+        out_size = (tenInput.shape[2], tenInput.shape[3])
+
+        tenFeat = tenInput
+        scores = []
+        for stage, head in zip(stages, heads):
+            tenFeat = stage(tenFeat)  # each stage consumes the previous stage's features
+            scores.append(torch.nn.functional.interpolate(
+                input=head(tenFeat), size=out_size, mode='bilinear', align_corners=False))
+
+        return self.netCombine(torch.cat(scores, 1))
+
+
+class Cannydetector:
+    """Callable wrapper that runs ``cv2.Canny`` with fixed hysteresis thresholds."""
+
+    def __init__(self, low_threshold, high_threshold):
+        self.low_thres, self.high_thres = low_threshold, high_threshold
+
+    def __call__(self, image_canny):
+        """Return the Canny edge map of ``image_canny`` at 512x512.
+
+        ``image_canny`` must expose ``.cpu().numpy()``; it is scaled by 255,
+        cast to uint8 and resized before edge detection.
+        (presumably values are in [0, 1] — confirm against the caller)
+        """
+        as_uint8 = (image_canny.cpu().numpy() * 255).astype(np.uint8)
+        resized = cv2.resize(as_uint8, (512, 512))
+        return cv2.Canny(resized, self.low_thres, self.high_thres)
+
+
+class HEDdetector:
+    """Callable edge detector that wraps ``Network`` with the BSDS500 checkpoint.
+
+    On construction the checkpoint is looked up under the relative directory
+    ``ckpts`` and downloaded there if missing; the network is then moved to
+    CUDA and put in eval mode.
+    """
+
+    def __init__(self):
+        annotator_ckpts_path = 'ckpts'
+        remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth"
+        modelpath = os.path.join(annotator_ckpts_path, "network-bsds500.pth")
+
+        checkpoint_missing = not os.path.exists(modelpath)
+        if checkpoint_missing:
+            # Imported lazily so basicsr is only required when a download is needed.
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+
+        self.netNetwork = Network(modelpath).cuda().eval()
+
+    def __call__(self, image_hed):
+        """Return a uint8 edge map for one ``(h, w, c)`` image tensor.
+
+        The image is rearranged to ``(1, c, h, w)``, run through the network
+        without gradients, scaled by 255 and clipped to [0, 255].
+        """
+        with torch.no_grad():
+            batched = rearrange(image_hed, 'h w c -> 1 c h w')
+            prediction = self.netNetwork(batched)[0]
+            edge_u8 = (prediction.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
+            return edge_u8[0]
+
+
+def nms(x, t, s):
+    """Thin a response map by keeping directional local maxima above a threshold.
+
+    Args:
+        x: 2-D array of responses (cast to float32 before blurring).
+        t: threshold applied to the surviving responses.
+        s: sigma passed to ``cv2.GaussianBlur`` (kernel size derived from it).
+
+    Returns:
+        uint8 array, 255 where a blurred response equals its dilation along at
+        least one of four 3x3 line directions and exceeds ``t``, else 0.
+    """
+    blurred = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+
+    # Line-shaped 3x3 structuring elements: horizontal, vertical, two diagonals.
+    line_kernels = [
+        np.array(rows, dtype=np.uint8)
+        for rows in (
+            [[0, 0, 0], [1, 1, 1], [0, 0, 0]],
+            [[0, 1, 0], [0, 1, 0], [0, 1, 0]],
+            [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
+            [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
+        )
+    ]
+
+    kept = np.zeros_like(blurred)
+    for kernel in line_kernels:
+        # A pixel equal to its dilation is a maximum along that line direction.
+        is_line_max = cv2.dilate(blurred, kernel=kernel) == blurred
+        np.putmask(kept, is_line_max, blurred)
+
+    binary = np.zeros_like(kept, dtype=np.uint8)
+    binary[kept > t] = 255
+    return binary
diff --git a/core/lib/camera_utils.py b/core/lib/camera_utils.py
new file mode 100755
index 0000000..883c902
--- /dev/null
+++ b/core/lib/camera_utils.py
@@ -0,0 +1,93 @@
+import torch
+import numpy as np
+from packaging import version as pver
+
+
+def custom_meshgrid(*args):
+    """Call ``torch.meshgrid`` with ``indexing='ij'`` when the installed torch accepts it.
+
+    Torch versions older than 1.10 are called without the ``indexing`` keyword.
+    """
+    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
+    accepts_indexing = pver.parse(torch.__version__) >= pver.parse('1.10')
+    if accepts_indexing:
+        return torch.meshgrid(*args, indexing='ij')
+    return torch.meshgrid(*args)
+
+def safe_normalize(x, eps=1e-20):
+    """Divide ``x`` by its L2 norm along the last axis.
+
+    The squared norm is clamped to at least ``eps`` before the square root, so
+    an all-zero vector is returned as zeros instead of NaN.
+    """
+    squared_norm = (x * x).sum(dim=-1, keepdim=True)
+    return x / squared_norm.clamp(min=eps).sqrt()
+
+
+
+# Autocast is explicitly disabled for the body of this function.
+@torch.cuda.amp.autocast(enabled=False)
+def get_rays(poses, intrinsics, H, W, N=-1, patch_size=1, coords=None):
+ ''' Build camera rays for an H x W image, optionally sub-sampling N pixels.
+ Args:
+ poses: [N/1, 4, 4], cam2world
+ intrinsics: [N/1, 4] tensor or [4] ndarray, read as (fx, fy, cx, cy)
+ H, W, N: int; N <= 0 keeps every pixel
+ patch_size: when > 1 (and coords is None) pixels are drawn as square patches
+ coords: optional explicit pixel coordinates; column 0 is multiplied by W,
+ column 1 is added (i.e. used as row, column)
+ Returns:
+ dict with
+ 'rays_o', 'rays_d': ray origins and directions (annotated [N, 3] by the original author)
+ 'i', 'j': integer pixel coordinates, present only when N > 0
+ '''
+
+ device = poses.device
+
+ # ndarray intrinsics unpack to four scalars; tensor intrinsics are split per column.
+ if isinstance(intrinsics, np.ndarray):
+ fx, fy, cx, cy = intrinsics
+ else:
+ fx, fy, cx, cy = intrinsics[:, 0], intrinsics[:, 1], intrinsics[:, 2], intrinsics[:, 3]
+
+ i, j = custom_meshgrid(torch.linspace(0, W - 1, W, device=device), torch.linspace(0, H - 1, H,
+ device=device)) # float
+ # Transpose + flatten, then shift by 0.5 so coordinates sit at pixel centres.
+ i = i.t().contiguous().view(-1) + 0.5
+ j = j.t().contiguous().view(-1) + 0.5
+
+ results = {}
+
+ if N > 0:
+
+ # Explicit coordinates take priority over patch / random sampling.
+ if coords is not None:
+ inds = coords[:, 0] * W + coords[:, 1]
+
+ elif patch_size > 1:
+
+ # randomly sample the top-left corner of each patch.
+ # NOTE(review): torch.randint's upper bound is exclusive, so a patch can
+ # never start at H - patch_size / W - patch_size — confirm this is intended.
+ num_patch = N // (patch_size**2)
+ inds_x = torch.randint(0, H - patch_size, size=[num_patch], device=device)
+ inds_y = torch.randint(0, W - patch_size, size=[num_patch], device=device)
+ inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2]
+
+ # create meshgrid for each patch
+ pi, pj = custom_meshgrid(torch.arange(patch_size, device=device), torch.arange(patch_size, device=device))
+ offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1) # [p^2, 2]
+
+ inds = inds.unsqueeze(1) + offsets.unsqueeze(0) # [np, p^2, 2]
+ inds = inds.view(-1, 2) # [N, 2]
+ inds = inds[:, 0] * W + inds[:, 1] # [N], flatten
+
+ else: # random sampling
+ inds = torch.randint(0, H * W, size=[N], device=device) # may duplicate
+
+ # Pick the sampled pixel centres out of the flattened grids.
+ i = torch.gather(i, -1, inds)
+ j = torch.gather(j, -1, inds)
+
+ # .long() truncates the +0.5 offset, giving back integer pixel indices.
+ results['i'] = i.long()
+ results['j'] = j.long()
+
+ else:
+ # NOTE(review): inds is not read again below in this branch.
+ inds = torch.arange(H * W, device=device)
+
+ # Camera-space directions: x to the right, y negated, z fixed at -1.
+ zs = -torch.ones_like(i) # z is flipped
+ xs = (i - cx) / fx
+ ys = -(j - cy) / fy
+ directions = torch.stack((xs, ys, zs), dim=-1) # [N, 3]
+ # do not normalize to get actual depth, ref: https://github.com/dunbar12138/DSNeRF/issues/29
+ # directions = directions / torch.norm(directions, dim=-1, keepdim=True)
+ # Rotate directions by the upper-left 3x3 block of each pose.
+ rays_d = (directions.unsqueeze(1) @ poses[:, :3, :3].transpose(-1, -2)).squeeze(
+ 1) # [N, 1, 3] @ [N, 3, 3] --> [N, 1, 3]
+
+ # Ray origin is the pose's translation column, broadcast to match rays_d.
+ rays_o = poses[:, :3, 3].expand_as(rays_d) # [N, 3]
+
+ results['rays_o'] = rays_o
+ results['rays_d'] = rays_d
+
+ # visualize_rays(rays_o[0].detach().cpu().numpy(), rays_d[0].detach().cpu().numpy())
+
+ return results
\ No newline at end of file
diff --git a/core/lib/chamfer.py b/core/lib/chamfer.py
new file mode 100755
index 0000000..8ec828e
--- /dev/null
+++ b/core/lib/chamfer.py
@@ -0,0 +1,254 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+from pytorch3d.ops.knn import knn_gather, knn_points
+from pytorch3d.structures.pointclouds import Pointclouds
+
+
+def _validate_chamfer_reduction_inputs(
+    batch_reduction: Union[str, None], point_reduction: str
+) -> None:
+    """Reject reduction names that the chamfer loss does not implement.
+
+    Args:
+        batch_reduction: reduction over the batch; "mean", "sum" or None.
+        point_reduction: reduction over the points; "mean" or "sum".
+
+    Raises:
+        ValueError: if either argument is outside its allowed set.
+    """
+    allowed = ("mean", "sum")
+    batch_ok = batch_reduction is None or batch_reduction in allowed
+    if not batch_ok:
+        raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
+    if point_reduction not in allowed:
+        raise ValueError('point_reduction must be one of ["mean", "sum"]')
+
+
+def _handle_pointcloud_input(
+    points: Union[torch.Tensor, Pointclouds],
+    lengths: Union[torch.Tensor, None],
+    normals: Union[torch.Tensor, None],
+):
+    """
+    Normalise a point-cloud argument to ``(padded_points, lengths, normals)``.
+
+    If ``points`` is a ``Pointclouds`` object, the padded points, the number
+    of points per cloud and the padded normals are read from it and the
+    ``lengths`` / ``normals`` arguments are ignored. If it is a tensor, it
+    must be 3-dimensional; ``lengths`` defaults to the size of the second
+    dimension for every batch element.
+
+    Raises:
+        ValueError: on a non-3-D points/normals tensor, a ``lengths`` tensor
+            that is not of shape (N,), a length larger than the padded size,
+            or an unsupported ``points`` type.
+    """
+    if isinstance(points, Pointclouds):
+        X = points.points_padded()
+        lengths = points.num_points_per_cloud()
+        normals = points.normals_padded()  # either a tensor or None
+    elif torch.is_tensor(points):
+        if points.ndim != 3:
+            raise ValueError("Expected points to be of shape (N, P, D)")
+        X = points
+        if lengths is not None:
+            if lengths.ndim != 1 or lengths.shape[0] != X.shape[0]:
+                raise ValueError("Expected lengths to be of shape (N,)")
+            if lengths.max() > X.shape[1]:
+                raise ValueError("A length value was too long")
+        if lengths is None:
+            # Homogeneous batch: every cloud uses the full padded length.
+            lengths = torch.full(
+                (X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
+            )
+        if normals is not None and normals.ndim != 3:
+            # Message fixed: the original was missing its closing parenthesis.
+            raise ValueError("Expected normals to be of shape (N, P, 3)")
+    else:
+        raise ValueError(
+            "The input pointclouds should be either "
+            + "Pointclouds objects or torch.Tensor of shape "
+            + "(minibatch, num_points, 3)."
+        )
+    return X, lengths, normals
+
+
+def _chamfer_distance_single_direction(
+ x,
+ y,
+ x_lengths,
+ y_lengths,
+ x_normals,
+ y_normals,
+ weights,
+ batch_reduction: Union[str, None],
+ point_reduction: str,
+ norm: int,
+ abs_cosine: bool,
+):
+ """
+ One-directional chamfer term: for every point in x, the distance to its
+ nearest neighbour in y (K=1 knn_points), reduced over points and batch.
+
+ Returns a tuple (cham_dist, cham_normals). cham_normals is None unless
+ both x_normals and y_normals are given; it is then built from one minus
+ the (optionally absolute) cosine similarity between each x normal and the
+ normal of its nearest neighbour in y.
+
+ Raises ValueError if y's batch size or feature dimension differ from x's,
+ or if weights has the wrong length or contains negative values.
+ """
+ return_normals = x_normals is not None and y_normals is not None
+
+ N, P1, D = x.shape
+
+ # Check if inputs are heterogeneous and create a lengths mask.
+ is_x_heterogeneous = (x_lengths != P1).any()
+ # x_mask is True at padded positions (index >= that cloud's length).
+ x_mask = (
+ torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
+ ) # shape [N, P1]
+ if y.shape[0] != N or y.shape[2] != D:
+ raise ValueError("y does not have the correct shape.")
+ if weights is not None:
+ if weights.size(0) != N:
+ raise ValueError("weights must be of shape (N,).")
+ if not (weights >= 0).all():
+ raise ValueError("weights cannot be negative.")
+ # All-zero weights: short-circuit with zero-valued results that are still
+ # computed from x and weights (multiplied by 0.0) rather than fresh constants.
+ if weights.sum() == 0.0:
+ weights = weights.view(N, 1)
+ if batch_reduction in ["mean", "sum"]:
+ return (
+ (x.sum((1, 2)) * weights).sum() * 0.0,
+ (x.sum((1, 2)) * weights).sum() * 0.0,
+ )
+ return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
+
+ # Placeholder so cham_norm_x is always bound, even without normals.
+ cham_norm_x = x.new_zeros(())
+
+ x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
+ cham_x = x_nn.dists[..., 0] # (N, P1)
+
+ # Padded points must not contribute to the sum.
+ if is_x_heterogeneous:
+ cham_x[x_mask] = 0.0
+
+ # Per-batch-element weights, applied in place.
+ if weights is not None:
+ cham_x *= weights.view(N, 1)
+
+ if return_normals:
+ # Gather the normals using the indices and keep only value for k=0
+ x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
+
+ cosine_sim = F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
+ # If abs_cosine, ignore orientation and take the absolute value of the cosine sim.
+ cham_norm_x = 1 - (torch.abs(cosine_sim) if abs_cosine else cosine_sim)
+
+ if is_x_heterogeneous:
+ cham_norm_x[x_mask] = 0.0
+
+ if weights is not None:
+ cham_norm_x *= weights.view(N, 1)
+ cham_norm_x = cham_norm_x.sum(1) # (N,)
+
+ # Apply point reduction
+ cham_x = cham_x.sum(1) # (N,)
+ if point_reduction == "mean":
+ # Clamp so an empty cloud divides by 1 instead of 0.
+ x_lengths_clamped = x_lengths.clamp(min=1)
+ cham_x /= x_lengths_clamped
+ if return_normals:
+ cham_norm_x /= x_lengths_clamped
+
+ if batch_reduction is not None:
+ # batch_reduction == "sum"
+ cham_x = cham_x.sum()
+ if return_normals:
+ cham_norm_x = cham_norm_x.sum()
+ if batch_reduction == "mean":
+ # Mean divides the summed value by the total weight, or by the batch size (at least 1).
+ div = weights.sum() if weights is not None else max(N, 1)
+ cham_x /= div
+ if return_normals:
+ cham_norm_x /= div
+
+ cham_dist = cham_x
+ cham_normals = cham_norm_x if return_normals else None
+ return cham_dist, cham_normals
+
+
+def chamfer_distance(
+    x,
+    y,
+    x_lengths=None,
+    y_lengths=None,
+    x_normals=None,
+    y_normals=None,
+    weights=None,
+    batch_reduction: Union[str, None] = "mean",
+    point_reduction: str = "mean",
+    norm: int = 2,
+    single_directional: bool = False,
+    abs_cosine: bool = True,
+):
+    """
+    Chamfer distance between two batches of point clouds.
+
+    Args:
+        x: FloatTensor of shape (N, P1, D) or a Pointclouds object: a batch of
+            N clouds with at most P1 points of dimension D each.
+        y: FloatTensor of shape (N, P2, D) or a Pointclouds object: a batch of
+            N clouds with at most P2 points of dimension D each.
+        x_lengths: Optional LongTensor (N,), number of valid points per cloud in x.
+        y_lengths: Optional LongTensor (N,), number of valid points per cloud in y.
+        x_normals: Optional FloatTensor of shape (N, P1, D).
+        y_normals: Optional FloatTensor of shape (N, P2, D).
+        weights: Optional FloatTensor (N,), per-batch-element weights used in
+            the reduction.
+        batch_reduction: "mean", "sum" or None; how to reduce across the batch.
+        point_reduction: "mean" or "sum"; how to reduce across points.
+        norm: 1 for L1 or 2 for L2 distances.
+        single_directional: if False (default) the loss sums the x->y and y->x
+            nearest-neighbour terms; if True only the x->y term is returned.
+        abs_cosine: if True (default) the normal loss is one minus the absolute
+            cosine similarity, so exactly opposite normals count as matching;
+            if False the sign of the cosine is kept.
+
+    Returns:
+        2-element tuple containing
+
+        - **loss**: reduced distance between the clouds in x and in y.
+        - **loss_normals**: reduced cosine distance between normals, or None
+          when x_normals and y_normals are None.
+
+    Raises:
+        ValueError: for an unsupported reduction or a norm other than 1 or 2.
+    """
+    _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
+
+    if norm not in (1, 2):
+        raise ValueError("Support for 1 or 2 norm.")
+
+    x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
+    y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
+
+    # Arguments that are identical for both directions.
+    shared_args = (weights, batch_reduction, point_reduction, norm, abs_cosine)
+
+    dist_xy, normals_xy = _chamfer_distance_single_direction(
+        x, y, x_lengths, y_lengths, x_normals, y_normals, *shared_args
+    )
+    if single_directional:
+        return dist_xy, normals_xy
+
+    # Reverse direction: swap the roles of x and y.
+    dist_yx, normals_yx = _chamfer_distance_single_direction(
+        y, x, y_lengths, x_lengths, y_normals, x_normals, *shared_args
+    )
+    total_normals = None if normals_xy is None else (normals_xy + normals_yx)
+    return dist_xy + dist_yx, total_normals
diff --git a/core/lib/color_network.py b/core/lib/color_network.py
new file mode 100755
index 0000000..c239ba5
--- /dev/null
+++ b/core/lib/color_network.py
@@ -0,0 +1,24 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .network_utils import Decoder, HashDecoder
+
+class ColorNetwork(nn.Module):
+    """Albedo predictor: a ``HashDecoder`` whose output is squashed by a sigmoid.
+
+    Args:
+        cfg: accepted for interface compatibility; not read by this class.
+        num_layers: forwarded to ``HashDecoder`` as its fourth positional argument.
+        hidden_dim: forwarded to ``HashDecoder`` as its second positional argument.
+        hash_max_res: forwarded as ``max_res``.
+        hash_num_levels: forwarded as ``num_levels``.
+
+    NOTE(review): the two literal 3s passed to ``HashDecoder`` are presumably
+    the input and output dimensions — confirm in ``network_utils``.
+    """
+
+    def __init__(self, cfg, num_layers=1, hidden_dim=32, hash_max_res=2048, hash_num_levels=16):
+        super().__init__()
+        self.num_layers = num_layers
+        self.hidden_dim = hidden_dim
+
+        self.net = HashDecoder(
+            3,
+            self.hidden_dim,
+            3,
+            self.num_layers,
+            max_res=hash_max_res,
+            num_levels=hash_num_levels,
+        )
+
+    def forward(self, x):
+        """Return ``sigmoid(self.net(x))``, i.e. values in (0, 1)."""
+        return torch.sigmoid(self.net(x))
\ No newline at end of file
diff --git a/core/lib/color_utils.py b/core/lib/color_utils.py
new file mode 100755
index 0000000..87740b9
--- /dev/null
+++ b/core/lib/color_utils.py
@@ -0,0 +1,48 @@
+import torch
+
+def rgb2xyz(var, device=None):
+    """Convert RGB values (last axis of size 3) to XYZ with a fixed 3x3 matrix.
+
+    Args:
+        var: tensor whose last dimension holds (R, G, B), expected in [0, 1].
+        device: device for the conversion matrix. Defaults to ``var.device``
+            so CPU tensors and tensors on any GPU work; an explicit value is
+            still honoured for existing callers.
+
+    Returns:
+        Tensor of the same shape as ``var``. The result is not renormalised:
+        the rows of the matrix sum to about 0.950, 1.000 and 1.089, so Z can
+        exceed 1 for white input.
+    """
+    target_device = var.device if device is None else device
+    # Follow the input's floating dtype so matmul does not fail on half/double input.
+    matrix_dtype = var.dtype if var.is_floating_point() else torch.float32
+    transform = torch.tensor([[0.412453, 0.357580, 0.180423],
+                              [0.212671, 0.715160, 0.072169],
+                              [0.019334, 0.119193, 0.950227]],
+                             dtype=matrix_dtype, device=target_device)
+    xyz = torch.matmul(var, transform.t())
+    return xyz
+
+def rgb2ycrcb(imgs):
+    """Convert RGB (last axis, expected in [0, 1]) to a luma/chroma representation.
+
+    Channels are scaled to 0..255, converted, then mapped with
+    ``(v - 16) / (240 - 16)``.
+
+    NOTE(review): despite the name, the output channels are stacked in the
+    order (Y, Cb, Cr), and the final mapping does not confine values to
+    [0, 1] (white gives Y of about 1.067) — confirm callers do not rely on either.
+    """
+    red, green, blue = (imgs[..., channel] * 255 for channel in range(3))
+    luma = 0.299*red + 0.587*green + 0.114*blue
+    chroma_red = (red - luma)*0.713 + 128
+    chroma_blue = (blue - luma)*0.564 + 128
+    stacked = torch.stack([luma, chroma_blue, chroma_red], -1)
+    return (stacked - 16) / (240 - 16)
+
+def rgb2srgb(imgs):
+    """Apply the piecewise sRGB transfer curve element-wise.
+
+    Values at or below 0.04045 are divided by 12.92; larger values go through
+    ``((v + 0.055) / 1.055) ** 2.4``.
+
+    NOTE(review): this is the sRGB *decoding* (sRGB -> linear) formula, which
+    the function name does not suggest — confirm the intended direction.
+    """
+    linear_segment = imgs / 12.92
+    power_segment = torch.pow((imgs + 0.055) / 1.055, 2.4)
+    return torch.where(imgs <= 0.04045, linear_segment, power_segment)
+
+def rgb2cmyk(imgs, device='cuda'):
+    """Convert RGB (last axis, expected in [0, 1]) to CMYK, clamped to [0, 1].
+
+    K is one minus the per-pixel channel maximum; C, M, Y are
+    ``(1 - channel - K) / (1 - K + 1e-7)``, the small constant guarding the
+    division for black pixels. ``device`` is accepted but not used.
+    """
+    key = 1 - torch.max(imgs, dim=-1).values
+    denominator = 1 - key + 1e-7
+    cmy = [(1 - imgs[..., channel] - key) / denominator for channel in range(3)]
+    return torch.stack(cmy + [key], -1).clamp(0, 1)
+
+def convert_rgb(imgs, target='rgb'):
+    """Convert RGB images to the colour space named by ``target``.
+
+    Args:
+        imgs: tensor with RGB in the last dimension.
+        target: one of 'rgb' (returned unchanged), 'cmyk', 'xyz', 'ycrcb', 'srgb'.
+
+    Returns:
+        The converted tensor, or ``imgs`` itself for 'rgb'.
+
+    Raises:
+        ValueError: if ``target`` is not a supported name. The original fell
+            through and returned None, which only failed later at the caller.
+    """
+    if target == 'rgb':
+        return imgs
+    if target == 'cmyk':
+        return rgb2cmyk(imgs)
+    if target == 'xyz':
+        return rgb2xyz(imgs)
+    if target == 'ycrcb':
+        return rgb2ycrcb(imgs)
+    if target == 'srgb':
+        return rgb2srgb(imgs)
+    raise ValueError(
+        "Unsupported target colour space {!r}; expected one of "
+        "'rgb', 'cmyk', 'xyz', 'ycrcb', 'srgb'.".format(target)
+    )
\ No newline at end of file
diff --git a/core/lib/dmtet_network.py b/core/lib/dmtet_network.py
new file mode 100755
index 0000000..118c203
--- /dev/null
+++ b/core/lib/dmtet_network.py
@@ -0,0 +1,200 @@
+import torch
+import torch.nn as nn
+import kaolin as kal
+from tqdm import tqdm
+import random
+import trimesh
+from .network_utils import Decoder, HashDecoder, HashDecoderNew
+# Laplacian regularization using umbrella operator (Fujiwara / Desbrun).
+# https://mgarland.org/class/geom04/material/smoothing.pdf
+def laplace_regularizer_const(mesh_verts, mesh_faces):
+ term = torch.zeros_like(mesh_verts)
+ norm = torch.zeros_like(mesh_verts[..., 0:1])
+
+ v0 = mesh_verts[mesh_faces[:, 0], :]
+ v1 = mesh_verts[mesh_faces[:, 1], :]
+ v2 = mesh_verts[mesh_faces[:, 2], :]
+
+ term.scatter_add_(0, mesh_faces[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0))
+ term.scatter_add_(0, mesh_faces[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1))
+ term.scatter_add_(0, mesh_faces[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2))
+
+ two = torch.ones_like(v0) * 2.0
+ norm.scatter_add_(0, mesh_faces[:, 0:1], two)
+ norm.scatter_add_(0, mesh_faces[:, 1:2], two)
+ norm.scatter_add_(0, mesh_faces[:, 2:3], two)
+
+ term = term / torch.clamp(norm, min=1.0)
+
+ return torch.mean(term**2)
+
+def loss_f(mesh_verts, mesh_faces, points, it):
+ pred_points = kal.ops.mesh.sample_points(mesh_verts.unsqueeze(0), mesh_faces, 50000)[0][0]
+ chamfer = kal.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), points.unsqueeze(0)).mean()
+ laplacian_weight = 0.1
+ if it > iterations//2:
+ lap = laplace_regularizer_const(mesh_verts, mesh_faces)
+ return chamfer + lap * laplacian_weight
+ return chamfer
+
+###############################################################################
+# Compact tet grid
+###############################################################################
+
+def compact_tets(pos_nx3, sdf_n, tet_fx4):
+ # Discard every tetrahedron whose four corners all give the same answer to
+ # the `sdf_n > 0` test, then re-index the vertices the surviving ones use.
+ # Returns (new_pos, new_sdf, new_tets); new_tets indexes into new_pos/new_sdf.
+ with torch.no_grad():
+ # Find surface tets
+ occ_n = sdf_n > 0
+ occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
+ occ_sum = torch.sum(occ_fx4, -1)
+ valid_tets = (occ_sum > 0) & (occ_sum < 4) # one value per tet, these are the surface tets
+
+ # Vertex ids used by the surface tets; idx_map gives, for each of them,
+ # its position inside unique_vtx (i.e. its index in the compacted arrays).
+ valid_vtx = tet_fx4[valid_tets].reshape(-1)
+ unique_vtx, idx_map = torch.unique(valid_vtx, dim=0, return_inverse=True)
+ new_pos = pos_nx3[unique_vtx]
+ new_sdf = sdf_n[unique_vtx]
+ new_tets = idx_map.reshape(-1, 4)
+ return new_pos, new_sdf, new_tets
+
+
+###############################################################################
+# Subdivide volume
+###############################################################################
+
+def sort_edges(edges_ex2):
+ with torch.no_grad():
+ order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
+ order = order.unsqueeze(dim=1)
+ a = torch.gather(input=edges_ex2, index=order, dim=1)
+ b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
+ return torch.stack([a, b], -1)
+
+
+def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4):
+ # Split every tetrahedron into 8 children by inserting a vertex at the
+ # midpoint of each of its 6 edges.
+ # tet_pos_bxnx3: vertex positions, batched along dim 0.
+ # tet_bxfx4: tetrahedron vertex indices; only tet_bxfx4[0] is read, so the
+ # same connectivity is used for every batch element.
+ # Returns (new_v, tet): the original vertices followed by the edge
+ # midpoints, and the subdivided connectivity expanded to the batch size.
+ device = tet_pos_bxnx3.device
+ # get new verts
+ tet_fx4 = tet_bxfx4[0]
+ # Local corner pairs of the 6 edges: ab, ac, ad, bc, bd, cd.
+ edges = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3]
+ all_edges = tet_fx4[:, edges].reshape(-1, 2)
+ all_edges = sort_edges(all_edges)
+ # Edges shared between tets collapse to one midpoint; idx_map maps each
+ # (tet, edge) slot to its unique edge.
+ unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
+ # Midpoints are appended after the existing vertices, hence the offset.
+ idx_map = idx_map + tet_pos_bxnx3.shape[1]
+ all_values = tet_pos_bxnx3
+ mid_points_pos = all_values[:, unique_edges.reshape(-1)].reshape(
+ all_values.shape[0], -1, 2,
+ all_values.shape[-1]).mean(2)
+ new_v = torch.cat([all_values, mid_points_pos], 1)
+
+ # get new tets
+
+ idx_a, idx_b, idx_c, idx_d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3]
+ # idx_map holds 6 consecutive entries per tet, in the edge order listed above.
+ idx_ab = idx_map[0::6]
+ idx_ac = idx_map[1::6]
+ idx_ad = idx_map[2::6]
+ idx_bc = idx_map[3::6]
+ idx_bd = idx_map[4::6]
+ idx_cd = idx_map[5::6]
+
+ # tet_1..tet_4 each keep one original corner; tet_5..tet_8 are built from
+ # edge midpoints only.
+ tet_1 = torch.stack([idx_a, idx_ab, idx_ac, idx_ad], dim=1)
+ tet_2 = torch.stack([idx_b, idx_bc, idx_ab, idx_bd], dim=1)
+ tet_3 = torch.stack([idx_c, idx_ac, idx_bc, idx_cd], dim=1)
+ tet_4 = torch.stack([idx_d, idx_ad, idx_cd, idx_bd], dim=1)
+ tet_5 = torch.stack([idx_ab, idx_ac, idx_ad, idx_bd], dim=1)
+ tet_6 = torch.stack([idx_ab, idx_ac, idx_bd, idx_bc], dim=1)
+ tet_7 = torch.stack([idx_cd, idx_ac, idx_bd, idx_ad], dim=1)
+ tet_8 = torch.stack([idx_cd, idx_ac, idx_bc, idx_bd], dim=1)
+
+ tet_np = torch.cat([tet_1, tet_2, tet_3, tet_4, tet_5, tet_6, tet_7, tet_8], dim=0)
+ tet_np = tet_np.reshape(1, -1, 4).expand(tet_pos_bxnx3.shape[0], -1, -1)
+ tet = tet_np.long().to(device)
+
+ return new_v, tet
+
+
+class DMTetMesh(nn.Module):
+ def __init__(self, vertices: torch.Tensor, indices: torch.Tensor, device: str='cuda', grid_scale=1e-4, use_explicit=False, geo_network='mlp', hash_max_res=1024, hash_num_levels=16, num_subdiv=0) -> None:
+ super().__init__()
+ self.device = device
+ self.tet_v = vertices.to(device)
+ self.tet_ind = indices.to(device)
+ self.use_explicit = use_explicit
+ if self.use_explicit:
+ self.sdf = nn.Parameter(torch.zeros_like(self.tet_v[:, 0]), requires_grad=True)
+ self.deform = nn.Parameter(torch.zeros_like(self.tet_v), requires_grad=True)
+ elif geo_network == 'mlp':
+ self.decoder = Decoder().to(device)
+ elif geo_network == 'hash':
+ pts_bounds = (self.tet_v.min(dim=0)[0], self.tet_v.max(dim=0)[0])
+ self.decoder = HashDecoder(input_bounds=pts_bounds, max_res=hash_max_res, num_levels=hash_num_levels).to(device)
+ self.grid_scale = grid_scale
+ self.num_subdiv = num_subdiv
+
+ def query_decoder(self, tet_v):
+ if self.tet_v.shape[0] < 1000000:
+ return self.decoder(tet_v)
+ else:
+ chunk_size = 1000000
+ results = []
+ for i in range((tet_v.shape[0] // chunk_size) + 1):
+ if i*chunk_size < tet_v.shape[0]:
+ results.append(self.decoder(tet_v[i*chunk_size: (i+1)*chunk_size]))
+ return torch.cat(results, dim=0)
+
+ def get_mesh(self, return_loss=False, num_subdiv=None):
+ if num_subdiv is None:
+ num_subdiv = self.num_subdiv
+ if self.use_explicit:
+ sdf = self.sdf * 1
+ deform = self.deform * 1
+ else:
+ pred = self.query_decoder(self.tet_v)
+ sdf, deform = pred[:,0], pred[:,1:]
+ verts_deformed = self.tet_v + torch.tanh(deform) * self.grid_scale / 2 # constraint deformation to avoid flipping tets
+ tet = self.tet_ind
+ for i in range(num_subdiv):
+ verts_deformed, _, tet = compact_tets(verts_deformed, sdf, tet)
+ verts_deformed, tet = batch_subdivide_volume(verts_deformed.unsqueeze(0), tet.unsqueeze(0))
+ verts_deformed = verts_deformed[0]
+ tet = tet[0]
+ pred = self.query_decoder(verts_deformed)
+ sdf, _ = pred[:,0], pred[:,1:]
+ mesh_verts, mesh_faces = kal.ops.conversions.marching_tetrahedra(verts_deformed.unsqueeze(0), tet, sdf.unsqueeze(0)) # running MT (batched) to extract surface mesh
+
+ mesh_verts, mesh_faces = mesh_verts[0], mesh_faces[0]
+ return mesh_verts, mesh_faces, None
+
+ def init_mesh(self, mesh_v, mesh_f, init_padding=0.):
+ num_pts = self.tet_v.shape[0]
+ mesh = trimesh.Trimesh(mesh_v.cpu().numpy(), mesh_f.cpu().numpy())
+ import mesh_to_sdf
+ sdf_tet = torch.tensor(mesh_to_sdf.mesh_to_sdf(mesh, self.tet_v.cpu().numpy()), dtype=torch.float32).to(self.device) - init_padding
+ sdf_mesh_v, sdf_mesh_f = kal.ops.conversions.marching_tetrahedra(self.tet_v.unsqueeze(0), self.tet_ind, sdf_tet.unsqueeze(0))
+ sdf_mesh_v, sdf_mesh_f = sdf_mesh_v[0], sdf_mesh_f[0]
+ if self.use_explicit:
+ self.sdf.data[...] = sdf_tet[...]
+ else:
+ optimizer = torch.optim.Adam(list(self.parameters()), lr=1e-3)
+ batch_size = 300000
+ iter = 1000
+ points, sdf_gt = mesh_to_sdf.sample_sdf_near_surface(mesh)
+ valid_idx = (points < self.tet_v.cpu().numpy().min(axis=0)).sum(-1) + (points > self.tet_v.cpu().numpy().max(axis=0)).sum(-1) == 0
+ points = points[valid_idx]
+ sdf_gt = sdf_gt[valid_idx]
+ points = torch.tensor(points, dtype=torch.float32).to(self.device)
+ sdf_gt = torch.tensor(sdf_gt, dtype=torch.float32).to(self.device)
+ points = torch.cat([points, self.tet_v], dim=0)
+ sdf_gt = torch.cat([sdf_gt, sdf_tet], dim=0)
+ num_pts = len(points)
+ for i in tqdm(range(iter)):
+ sampled_ind = random.sample(range(num_pts), min(batch_size, num_pts))
+ p = points[sampled_ind]
+ pred = self.decoder(p)
+ sdf, deform = pred[:,0], pred[:,1:]
+ loss = nn.functional.mse_loss(sdf, sdf_gt[sampled_ind])# + (deform ** 2).mean()
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ with torch.no_grad():
+ mesh_v, mesh_f, _ = self.get_mesh(return_loss=False)
+ pred_mesh = trimesh.Trimesh(mesh_v.cpu().numpy(), mesh_f.cpu().numpy())
+ print('fitted mesh with num_vertex {}, num_faces {}'.format(mesh_v.shape[0], mesh_f.shape[0]))
\ No newline at end of file
diff --git a/core/lib/encoding.py b/core/lib/encoding.py
new file mode 100755
index 0000000..421c3dc
--- /dev/null
+++ b/core/lib/encoding.py
@@ -0,0 +1,77 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class FreqEncoder_torch(nn.Module):
+ def __init__(self, input_dim, max_freq_log2, N_freqs,
+ log_sampling=True, include_input=True,
+ periodic_fns=(torch.sin, torch.cos)):
+
+ super().__init__()
+
+ self.input_dim = input_dim
+ self.include_input = include_input
+ self.periodic_fns = periodic_fns
+
+ self.output_dim = 0
+ if self.include_input:
+ self.output_dim += self.input_dim
+
+ self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)
+
+ if log_sampling:
+ self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs)
+ else:
+ self.freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs)
+
+ self.freq_bands = self.freq_bands.numpy().tolist()
+
+ def forward(self, input, **kwargs):
+
+ out = []
+ if self.include_input:
+ out.append(input)
+
+ for i in range(len(self.freq_bands)):
+ freq = self.freq_bands[i]
+ for p_fn in self.periodic_fns:
+ out.append(p_fn(input * freq))
+
+ out = torch.cat(out, dim=-1)
+
+ return out
+
+def get_encoder(encoding, input_dim=3,
+ multires=6,
+ degree=4,
+ num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, interpolation='linear', input_bounds=None,
+ **kwargs):
+
+ if encoding == 'None':
+ return lambda x, **kwargs: x, input_dim
+
+ elif encoding == 'frequency_torch':
+ encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)
+
+ elif encoding == 'frequency': # CUDA implementation, faster than torch.
+ from freqencoder import FreqEncoder
+ encoder = FreqEncoder(input_dim=input_dim, degree=multires)
+
+ elif encoding == 'sphere_harmonics':
+ from shencoder import SHEncoder
+ encoder = SHEncoder(input_dim=input_dim, degree=degree)
+
+ elif encoding == 'hashgrid':
+ from gridencoder import GridEncoder
+ encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners, interpolation=interpolation)
+ # from .encoding_2 import HashEncoding
+ # encoder = HashEncoding()
+
+ elif encoding == 'tiledgrid':
+ from gridencoder import GridEncoder
+ encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, interpolation=interpolation)
+
+ else:
+ raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]')
+
+ return encoder, encoder.output_dim
diff --git a/core/lib/freqencoder/__init__.py b/core/lib/freqencoder/__init__.py
new file mode 100755
index 0000000..69ec49c
--- /dev/null
+++ b/core/lib/freqencoder/__init__.py
@@ -0,0 +1 @@
+from .freq import FreqEncoder
\ No newline at end of file
diff --git a/core/lib/freqencoder/backend.py b/core/lib/freqencoder/backend.py
new file mode 100755
index 0000000..3bd9131
--- /dev/null
+++ b/core/lib/freqencoder/backend.py
@@ -0,0 +1,41 @@
+import os
+from torch.utils.cpp_extension import load
+
+# JIT-build the frequency-encoder CUDA extension with
+# torch.utils.cpp_extension.load and expose the loaded module as `_backend`.
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+# Flags handed to nvcc; the -U switches undefine the macros named in them.
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+ '-use_fast_math'
+]
+
+# Host-compiler flags per platform.
+# NOTE(review): c_flags is only assigned for os.name "posix" or "nt"; any
+# other value would surface as a NameError at the load() call below.
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ # Returns the newest matching MSVC x64 bin directory, trying the editions
+ # in the order listed, or None when nothing matches.
+ def find_cl_path():
+ import glob
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+# Compile and load the sources under ./src next to this file.
+_backend = load(name='_freqencoder',
+ extra_cflags=c_flags,
+ extra_cuda_cflags=nvcc_flags,
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'freqencoder.cu',
+ 'bindings.cpp',
+ ]],
+ )
+
+__all__ = ['_backend']
\ No newline at end of file
diff --git a/core/lib/freqencoder/freq.py b/core/lib/freqencoder/freq.py
new file mode 100755
index 0000000..5cba1e6
--- /dev/null
+++ b/core/lib/freqencoder/freq.py
@@ -0,0 +1,77 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+try:
+ import _freqencoder as _backend
+except ImportError:
+ from .backend import _backend
+
+
+class _freq_encoder(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
+ def forward(ctx, inputs, degree, output_dim):
+ # inputs: [B, input_dim], float
+ # RETURN: [B, F], float
+
+ if not inputs.is_cuda: inputs = inputs.cuda()
+ inputs = inputs.contiguous()
+
+ B, input_dim = inputs.shape # batch size, coord dim
+
+ outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
+
+ _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
+
+ ctx.save_for_backward(inputs, outputs)
+ ctx.dims = [B, input_dim, degree, output_dim]
+
+ return outputs
+
+ @staticmethod
+ #@once_differentiable
+ @custom_bwd
+ def backward(ctx, grad):
+ # grad: [B, C * C]
+
+ grad = grad.contiguous()
+ inputs, outputs = ctx.saved_tensors
+ B, input_dim, degree, output_dim = ctx.dims
+
+ grad_inputs = torch.zeros_like(inputs)
+ _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
+
+ return grad_inputs, None, None
+
+
+freq_encode = _freq_encoder.apply
+
+
+class FreqEncoder(nn.Module):
+ def __init__(self, input_dim=3, degree=4):
+ super().__init__()
+
+ self.input_dim = input_dim
+ self.degree = degree
+ self.output_dim = input_dim + input_dim * 2 * degree
+
+ def __repr__(self):
+ return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"
+
+ def forward(self, inputs, **kwargs):
+ # inputs: [..., input_dim]
+ # return: [..., ]
+
+ prefix_shape = list(inputs.shape[:-1])
+ inputs = inputs.reshape(-1, self.input_dim)
+
+ outputs = freq_encode(inputs, self.degree, self.output_dim)
+
+ outputs = outputs.reshape(prefix_shape + [self.output_dim])
+
+ return outputs
\ No newline at end of file
diff --git a/core/lib/freqencoder/setup.py b/core/lib/freqencoder/setup.py
new file mode 100755
index 0000000..3eb4af7
--- /dev/null
+++ b/core/lib/freqencoder/setup.py
@@ -0,0 +1,51 @@
+import os
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+# Build script that compiles the frequency-encoder sources into the
+# `_freqencoder` CUDA extension (ahead-of-time counterpart of backend.py).
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+# Flags handed to nvcc; the -U switches undefine the macros named in them.
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+ '-use_fast_math'
+]
+
+# Host-compiler flags per platform.
+# NOTE(review): c_flags is only assigned for os.name "posix" or "nt"; any
+# other value would surface as a NameError inside setup() below.
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ # Returns the newest matching MSVC x64 bin directory, trying the editions
+ # in the order listed, or None when nothing matches.
+ def find_cl_path():
+ import glob
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+setup(
+ name='freqencoder', # package name, import this to use python API
+ ext_modules=[
+ CUDAExtension(
+ name='_freqencoder', # extension name, import this to use CUDA API
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'freqencoder.cu',
+ 'bindings.cpp',
+ ]],
+ extra_compile_args={
+ 'cxx': c_flags,
+ 'nvcc': nvcc_flags,
+ }
+ ),
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension,
+ }
+)
\ No newline at end of file
diff --git a/core/lib/freqencoder/src/bindings.cpp b/core/lib/freqencoder/src/bindings.cpp
new file mode 100755
index 0000000..bb5f285
--- /dev/null
+++ b/core/lib/freqencoder/src/bindings.cpp
@@ -0,0 +1,8 @@
+#include <torch/extension.h>
+
+#include "freqencoder.h"
+
+// Python bindings for the frequency encoder: exposes the two host-side
+// launchers declared in freqencoder.h under the extension module's name.
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
+    m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
+}
\ No newline at end of file
diff --git a/core/lib/freqencoder/src/freqencoder.cu b/core/lib/freqencoder/src/freqencoder.cu
new file mode 100755
index 0000000..072da74
--- /dev/null
+++ b/core/lib/freqencoder/src/freqencoder.cu
@@ -0,0 +1,129 @@
+#include <stdint.h>
+
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/torch.h>
+
+#include <algorithm>
+#include <stdexcept>
+
+#include <cstdio>
+
+
+// Argument checks used by the host-side launchers; each one raises through
+// TORCH_CHECK with the offending expression's name in the message.
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
+#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
+
+// Single-precision pi, usable in device code.
+inline constexpr __device__ float PI() { return 3.141592653589793f; }
+
+// Integer division rounded up: the smallest q with q * divisor >= val
+// (for non-negative val and positive divisor).
+template <typename T>
+__host__ __device__ T div_round_up(T val, T divisor) {
+    return (val + divisor - 1) / divisor;
+}
+
+// inputs: [B, D]
+// outputs: [B, C], C = D + D * deg * 2
+// One thread per output element. The first D columns copy the input; the
+// remaining columns come in groups of D that alternate sin(2^f * x) and
+// cos(2^f * x), the cosine being evaluated as a sine shifted by pi/2.
+__global__ void kernel_freq(
+    const float * __restrict__ inputs,
+    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
+    float * outputs
+) {
+    const uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx >= B * C) return;
+
+    const uint32_t row = idx / C;
+    const uint32_t col = idx - row * C; // idx % C
+    const float *in_row = inputs + row * D;
+
+    if (col < D) {
+        // pass-through part
+        outputs[idx] = in_row[col];
+        return;
+    }
+
+    const uint32_t group = col / D - 1;   // 0 = sin of band 0, 1 = cos of band 0, ...
+    const uint32_t band = group / 2;
+    const float shift = (group % 2) * (PI() / 2);
+    outputs[idx] = __sinf(scalbnf(in_row[col % D], band) + shift);
+}
+
+// grad: [B, C], C = D + D * deg * 2
+// outputs: [B, C]
+// grad_inputs: [B, D]
+// One thread per input element. The derivative is rebuilt from the forward
+// outputs: d/dx sin(2^f x) = 2^f * cos(2^f x) and d/dx cos(2^f x) =
+// -2^f * sin(2^f x), where the sin/cos values are read back from `outputs`.
+__global__ void kernel_freq_backward(
+    const float * __restrict__ grad,
+    const float * __restrict__ outputs,
+    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
+    float * grad_inputs
+) {
+    const uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx >= B * D) return;
+
+    const uint32_t row = idx / D;
+    const uint32_t d = idx - row * D; // idx % D
+
+    const float *g_row = grad + row * C;
+    const float *o_row = outputs + row * C;
+
+    // contribution of the pass-through columns
+    float acc = g_row[d];
+
+    for (uint32_t f = 0; f < deg; f++) {
+        // start of this band's sin group; its cos group follows D entries later
+        const uint32_t sin_at = D + 2 * f * D + d;
+        const uint32_t cos_at = sin_at + D;
+        acc += scalbnf(1.0f, f) * (g_row[sin_at] * o_row[cos_at] - g_row[cos_at] * o_row[sin_at]);
+    }
+
+    grad_inputs[idx] = acc;
+}
+
+
+// Host-side launcher for kernel_freq.
+//   inputs:  [B, D] contiguous CUDA tensor, read as float
+//   outputs: [B, C] contiguous CUDA tensor, written as float
+// Launches one thread per output element, N_THREADS threads per block.
+void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
+    CHECK_CUDA(inputs);
+    CHECK_CUDA(outputs);
+
+    CHECK_CONTIGUOUS(inputs);
+    CHECK_CONTIGUOUS(outputs);
+
+    CHECK_IS_FLOATING(inputs);
+    CHECK_IS_FLOATING(outputs);
+
+    static constexpr uint32_t N_THREADS = 128;
+
+    kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
+}
+
+
+// Host-side launcher for kernel_freq_backward.
+//   grad:        [B, C] gradient w.r.t. the forward outputs
+//   outputs:     [B, C] forward outputs
+//   grad_inputs: [B, D] written with the gradient w.r.t. the inputs
+// All tensors must be contiguous CUDA tensors and are accessed as float.
+// Launches one thread per input element, N_THREADS threads per block.
+void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
+    CHECK_CUDA(grad);
+    CHECK_CUDA(outputs);
+    CHECK_CUDA(grad_inputs);
+
+    CHECK_CONTIGUOUS(grad);
+    CHECK_CONTIGUOUS(outputs);
+    CHECK_CONTIGUOUS(grad_inputs);
+
+    CHECK_IS_FLOATING(grad);
+    CHECK_IS_FLOATING(outputs);
+    CHECK_IS_FLOATING(grad_inputs);
+
+    static constexpr uint32_t N_THREADS = 128;
+
+    kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
+}
\ No newline at end of file
diff --git a/core/lib/freqencoder/src/freqencoder.h b/core/lib/freqencoder/src/freqencoder.h
new file mode 100755
index 0000000..34f28c7
--- /dev/null
+++ b/core/lib/freqencoder/src/freqencoder.h
@@ -0,0 +1,10 @@
+# pragma once
+
+#include <stdint.h>
+#include <torch/torch.h>
+
+// Writes the frequency encoding of `inputs` ([B, D]) into `outputs` ([B, C]).
+// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
+void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);
+
+// Writes the gradient w.r.t. the inputs ([B, D]) into `grad_inputs`, given the
+// output gradient `grad` and the forward `outputs` (both [B, C]).
+// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
+void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
\ No newline at end of file
diff --git a/core/lib/gridencoder/__init__.py b/core/lib/gridencoder/__init__.py
new file mode 100755
index 0000000..f1476ce
--- /dev/null
+++ b/core/lib/gridencoder/__init__.py
@@ -0,0 +1 @@
+from .grid import GridEncoder
\ No newline at end of file
diff --git a/core/lib/gridencoder/backend.py b/core/lib/gridencoder/backend.py
new file mode 100755
index 0000000..d99acb1
--- /dev/null
+++ b/core/lib/gridencoder/backend.py
@@ -0,0 +1,40 @@
+import os
+from torch.utils.cpp_extension import load
+
+# JIT-build the grid-encoder CUDA extension with
+# torch.utils.cpp_extension.load and expose the loaded module as `_backend`.
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+# Flags handed to nvcc; the -U switches undefine the macros named in them.
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+]
+
+# Host-compiler flags per platform.
+# NOTE(review): c_flags is only assigned for os.name "posix" or "nt"; any
+# other value would surface as a NameError at the load() call below.
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ # Returns the newest matching MSVC x64 bin directory, trying the editions
+ # in the order listed, or None when nothing matches.
+ def find_cl_path():
+ import glob
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+# Compile and load the sources under ./src next to this file.
+# NOTE(review): the module is loaded under the name '_grid_encoder', while
+# setup.py builds it as '_gridencoder'; confirm the difference is intended.
+_backend = load(name='_grid_encoder',
+ extra_cflags=c_flags,
+ extra_cuda_cflags=nvcc_flags,
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'gridencoder.cu',
+ 'bindings.cpp',
+ ]],
+ )
+
+__all__ = ['_backend']
\ No newline at end of file
diff --git a/core/lib/gridencoder/grid.py b/core/lib/gridencoder/grid.py
new file mode 100755
index 0000000..32b8bea
--- /dev/null
+++ b/core/lib/gridencoder/grid.py
@@ -0,0 +1,185 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+try:
+ import _gridencoder as _backend
+except ImportError:
+ from .backend import _backend
+
+# Integer id for each supported `gridtype` string; GridEncoder stores the id
+# and passes it to the backend calls.
+_gridtype_to_id = {
+ 'hash': 0,
+ 'tiled': 1,
+}
+
+# Integer id for each supported `interpolation` string, used the same way.
+_interp_to_id = {
+ 'linear': 0,
+ 'smoothstep': 1,
+}
+
+# Autograd Function wrapping the backend's grid_encode_forward /
+# grid_encode_backward calls.
+class _grid_encode(Function):
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0):
+ # inputs: [B, D], float in [0, 1]
+ # embeddings: [sO, C], float
+ # offsets: [L + 1], int
+ # RETURN: [B, F], float
+
+ inputs = inputs.contiguous()
+
+ B, D = inputs.shape # batch size, coord dim
+ L = offsets.shape[0] - 1 # level
+ C = embeddings.shape[1] # embedding dim for each level
+ S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
+ H = base_resolution # base resolution
+
+ # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
+ # if C % 2 != 0, force float, since half for atomicAdd is very slow.
+ if torch.is_autocast_enabled() and C % 2 == 0:
+ embeddings = embeddings.to(torch.half)
+
+ # L first, optimize cache for cuda kernel, but needs an extra permute later
+ outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
+
+ # dy_dx is only allocated when a gradient w.r.t. the inputs was requested;
+ # backward uses its presence to decide whether to produce grad_inputs.
+ if calc_grad_inputs:
+ dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
+ else:
+ dy_dx = None
+
+ _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners, interpolation)
+
+ # permute back to [B, L * C]
+ outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
+
+ ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
+ ctx.dims = [B, D, C, L, S, H, gridtype, interpolation]
+ ctx.align_corners = align_corners
+
+ return outputs
+
+ @staticmethod
+ #@once_differentiable
+ @custom_bwd
+ def backward(ctx, grad):
+ # grad: [B, L * C], gradient w.r.t. the forward output.
+ # Returns one entry per forward argument (9 in total); only inputs and
+ # embeddings can receive a gradient, the rest are None.
+
+ inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
+ B, D, C, L, S, H, gridtype, interpolation = ctx.dims
+ align_corners = ctx.align_corners
+
+ # grad: [B, L * C] --> [L, B, C]
+ grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
+
+ grad_embeddings = torch.zeros_like(embeddings)
+
+ if dy_dx is not None:
+ grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
+ else:
+ grad_inputs = None
+
+ _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation)
+
+ # grad_inputs was allocated in the embeddings' dtype; cast it back to the inputs' dtype.
+ if dy_dx is not None:
+ grad_inputs = grad_inputs.to(inputs.dtype)
+
+ return grad_inputs, grad_embeddings, None, None, None, None, None, None, None
+
+
+
+grid_encode = _grid_encode.apply
+
+
+# Multi-level grid encoder: one learnable embedding table shared by all
+# levels (each level owns a slice described by `offsets`), evaluated through
+# grid_encode.
+class GridEncoder(nn.Module):
+ def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'):
+ super().__init__()
+
+ # the finest resolution desired at the last level, if provided, overrides per_level_scale
+ if desired_resolution is not None:
+ per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
+
+ self.input_dim = input_dim # coord dims, 2 or 3
+ self.num_levels = num_levels # number of resolution levels
+ self.level_dim = level_dim # encode channels per level
+ self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
+ self.log2_hashmap_size = log2_hashmap_size
+ self.base_resolution = base_resolution
+ self.output_dim = num_levels * level_dim
+ self.gridtype = gridtype
+ self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
+ self.interpolation = interpolation
+ self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep"
+ self.align_corners = align_corners
+
+ # allocate parameters
+ # offsets[i] is the first embedding row of level i; a final entry holds
+ # the total row count. Each level is capped at 2 ** log2_hashmap_size
+ # rows and rounded up to a multiple of 8.
+ offsets = []
+ offset = 0
+ self.max_params = 2 ** log2_hashmap_size
+ for i in range(num_levels):
+ resolution = int(np.ceil(base_resolution * per_level_scale ** i))
+ params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number
+ params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
+ offsets.append(offset)
+ offset += params_in_level
+ offsets.append(offset)
+ offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
+ self.register_buffer('offsets', offsets)
+
+ # NOTE(review): this is a 0-dim tensor rather than a Python int.
+ self.n_params = offsets[-1] * level_dim
+
+ # parameters
+ self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ # Initialise the embeddings uniformly in [-1e-4, 1e-4].
+ std = 1e-4
+ self.embeddings.data.uniform_(-std, std)
+
+ def __repr__(self):
+ return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}"
+
+ def forward(self, inputs, bound=1):
+ # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
+ # return: [..., num_levels * level_dim]
+
+ inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
+
+ #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
+
+ prefix_shape = list(inputs.shape[:-1])
+ # NOTE(review): .view() requires a layout that can be flattened without
+ # a copy; confirm callers never pass non-contiguous inputs.
+ inputs = inputs.view(-1, self.input_dim)
+
+ # inputs.requires_grad decides whether a gradient w.r.t. the inputs is computed.
+ outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id)
+ outputs = outputs.view(prefix_shape + [self.output_dim])
+
+ #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
+
+ return outputs
+
+ # always run in float precision!
+ @torch.cuda.amp.autocast(enabled=False)
+ def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000):
+ # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss.
+ # Hands self.embeddings.grad to the backend's grad_total_variation call;
+ # presumably the backend accumulates the TV gradient into it in place -
+ # verify against the CUDA source. Returns nothing.
+
+ D = self.input_dim
+ C = self.embeddings.shape[1] # embedding dim for each level
+ L = self.offsets.shape[0] - 1 # level
+ S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
+ H = self.base_resolution # base resolution
+
+ if inputs is None:
+ # randomized in [0, 1]
+ inputs = torch.rand(B, self.input_dim, device=self.embeddings.device)
+ else:
+ inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
+ inputs = inputs.view(-1, self.input_dim)
+ B = inputs.shape[0]
+
+ if self.embeddings.grad is None:
+ raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')
+
+ _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners)
\ No newline at end of file
diff --git a/core/lib/gridencoder/setup.py b/core/lib/gridencoder/setup.py
new file mode 100755
index 0000000..714bf1c
--- /dev/null
+++ b/core/lib/gridencoder/setup.py
@@ -0,0 +1,50 @@
+import os
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+# Build script that compiles the grid-encoder sources into the
+# `_gridencoder` CUDA extension (ahead-of-time counterpart of backend.py).
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+# Flags handed to nvcc; the -U switches undefine the macros named in them.
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+]
+
+# Host-compiler flags per platform.
+# NOTE(review): c_flags is only assigned for os.name "posix" or "nt"; any
+# other value would surface as a NameError inside setup() below.
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ # Returns the newest matching MSVC x64 bin directory, trying the editions
+ # in the order listed, or None when nothing matches.
+ def find_cl_path():
+ import glob
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+setup(
+ name='gridencoder', # package name, import this to use python API
+ ext_modules=[
+ CUDAExtension(
+ name='_gridencoder', # extension name, import this to use CUDA API
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'gridencoder.cu',
+ 'bindings.cpp',
+ ]],
+ extra_compile_args={
+ 'cxx': c_flags,
+ 'nvcc': nvcc_flags,
+ }
+ ),
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension,
+ }
+)
\ No newline at end of file
diff --git a/core/lib/gridencoder/src/bindings.cpp b/core/lib/gridencoder/src/bindings.cpp
new file mode 100755
index 0000000..93dea94
--- /dev/null
+++ b/core/lib/gridencoder/src/bindings.cpp
@@ -0,0 +1,9 @@
+#include
+
+#include "gridencoder.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Python bindings for the three host entry points declared in gridencoder.h
+ m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
+ m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
+ m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)");
+}
\ No newline at end of file
diff --git a/core/lib/gridencoder/src/gridencoder.cu b/core/lib/gridencoder/src/gridencoder.cu
new file mode 100755
index 0000000..fdd49cb
--- /dev/null
+++ b/core/lib/gridencoder/src/gridencoder.cu
@@ -0,0 +1,642 @@
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
+#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
+
+
+// just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here!
+ __device__ inline at::Half atomicAdd(at::Half *address, at::Half val) {
+ // a real implementation requires CUDA >= 10 and ARCH >= 70; it is very slow compared to float or __half2, never use it:
+ //return atomicAdd(reinterpret_cast<__half*>(address), val);
+ return *address; // still a deliberate no-op (nothing is added), but a non-void function must not flow off its end (undefined behaviour)
+}
+
+
+template
+__host__ __device__ inline T div_round_up(T val, T divisor) { // ceiling division; used below to size launch grids
+ return (val + divisor - 1) / divisor; // assumes an integer T and divisor > 0 — TODO confirm for other instantiations
+}
+
+template
+__host__ __device__ inline T clamp(const T v, const T2 lo, const T2 hi) { // limit v to the interval [lo, hi]
+ return min(max(v, lo), hi); // raise to lo first, then cap at hi
+}
+
+template
+__device__ inline T smoothstep(T val) { // cubic smoothstep polynomial 3*val^2 - 2*val^3
+ return val*val*(3.0f - 2.0f * val); // callers in this file pass the fractional grid position
+}
+
+template
+__device__ inline T smoothstep_derivative(T val) { // derivative of smoothstep(val) with respect to val
+ return 6*val*(1.0f - val); // d/dval (3*val^2 - 2*val^3) = 6*val - 6*val^2
+}
+
+
+template
+__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
+ // Hash a D-dimensional integer grid coordinate to 32 bits: XOR of pos_grid[i] * primes[i] (uint32_t products wrap around).
+ // coherent type of hashing
+ constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u }; // 7 entries, so D must not exceed 7; primes[0] == 1 leaves the first axis unscaled
+
+ uint32_t result = 0;
+ #pragma unroll
+ for (uint32_t i = 0; i < D; ++i) {
+ result ^= pos_grid[i] * primes[i];
+ }
+
+ return result;
+}
+
+
+template
+__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
+ uint32_t stride = 1;
+ uint32_t index = 0;
+ // Dense index: sum of pos_grid[d] * stride, the stride growing per axis by resolution (align_corners) or resolution + 1. The loop stops once stride exceeds hashmap_size.
+ #pragma unroll
+ for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
+ index += pos_grid[d] * stride;
+ stride *= align_corners ? resolution: (resolution + 1);
+ }
+ // The value returned is the element offset of channel `ch` for this grid point: (index % hashmap_size) * C + ch.
+ // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
+ // gridtype: 0 == hash, 1 == tiled
+ if (gridtype == 0 && stride > hashmap_size) { // hash grid whose stride outgrew the table: replace the dense index by the hash
+ index = fast_hash(pos_grid);
+ }
+
+ return (index % hashmap_size) * C + ch;
+}
+
+
+template
+__global__ void kernel_grid( // forward pass: one thread interpolates the C features of one sample at one level
+ const float * __restrict__ inputs, // read as inputs[b * D + d]; any coordinate outside [0, 1] yields zero output
+ const scalar_t * __restrict__ grid,
+ const int * __restrict__ offsets, // offsets[level] .. offsets[level + 1] delimit this level's entries in grid
+ scalar_t * __restrict__ outputs, // written at [level, b, ch]
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
+ scalar_t * __restrict__ dy_dx, // may be null; otherwise written at [b, level, d, ch]
+ const uint32_t gridtype,
+ const bool align_corners,
+ const uint32_t interp
+) {
+ const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (b >= B) return;
+
+ const uint32_t level = blockIdx.y;
+
+ // locate
+ grid += (uint32_t)offsets[level] * C;
+ inputs += b * D;
+ outputs += level * B * C + b * C;
+
+ // check input range (should be in [0, 1])
+ bool flag_oob = false;
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if (inputs[d] < 0 || inputs[d] > 1) {
+ flag_oob = true;
+ }
+ }
+ // if input out of bound, just set output to 0
+ if (flag_oob) {
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ outputs[ch] = 0;
+ }
+ if (dy_dx) {
+ dy_dx += b * D * L * C + level * D * C; // B L D C
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ dy_dx[d * C + ch] = 0;
+ }
+ }
+ }
+ return;
+ }
+
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+ const float scale = exp2f(level * S) * H - 1.0f;
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
+
+ // calculate coordinate (always use float for precision!)
+ float pos[D];
+ float pos_deriv[D]; // per-axis derivative of the interpolation weight; every element is assigned in the loop below
+ uint32_t pos_grid[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
+ pos_grid[d] = floorf(pos[d]);
+ pos[d] -= (float)pos_grid[d];
+ pos_deriv[d] = 1.0f; // linear deriv is 1 on EVERY axis (an aggregate initialiser `= {1.0f}` would only set axis 0 and zero the rest)
+ if (interp == 1) { // smoothstep instead of linear
+ pos_deriv[d] = smoothstep_derivative(pos[d]);
+ pos[d] = smoothstep(pos[d]);
+ }
+ }
+
+ //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
+
+ // interpolate
+ scalar_t results[C] = {0}; // temp results in register
+
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << D); idx++) { // bit d of idx picks the lower (0) or upper (1) corner along axis d
+ float w = 1;
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if ((idx & (1 << d)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = pos_grid[d] + 1;
+ }
+ }
+
+ uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
+
+ // writing to register (fast)
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ results[ch] += w * grid[index + ch];
+ }
+
+ //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
+ }
+
+ // writing to global memory (slow)
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ outputs[ch] = results[ch];
+ }
+
+ // prepare dy_dx
+ // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
+ if (dy_dx) {
+
+ dy_dx += b * D * L * C + level * D * C; // B L D C
+
+ #pragma unroll
+ for (uint32_t gd = 0; gd < D; gd++) { // gd: the input axis being differentiated
+
+ scalar_t results_grad[C] = {0};
+
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) { // corners over the remaining D - 1 axes
+ float w = scale; // chain-rule factor: pos = inputs * scale (+ offset)
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t nd = 0; nd < D - 1; nd++) {
+ const uint32_t d = (nd >= gd) ? (nd + 1) : nd; // skip over axis gd
+
+ if ((idx & (1 << nd)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = pos_grid[d] + 1;
+ }
+ }
+
+ pos_grid_local[gd] = pos_grid[gd];
+ uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
+ pos_grid_local[gd] = pos_grid[gd] + 1;
+ uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd];
+ }
+ }
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ dy_dx[gd * C + ch] = results_grad[ch];
+ }
+ }
+ }
+}
+
+
+template
+__global__ void kernel_grid_backward( // scatter output gradients onto the grid entries used by the forward interpolation
+ const scalar_t * __restrict__ grad, // read starting at [level, b, ch]
+ const float * __restrict__ inputs,
+ const scalar_t * __restrict__ grid, // not referenced in this kernel body
+ const int * __restrict__ offsets,
+ scalar_t * __restrict__ grad_grid, // accumulated with atomicAdd
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
+ const uint32_t gridtype,
+ const bool align_corners,
+ const uint32_t interp
+) {
+ const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; // each thread covers N_C consecutive channels of one sample
+ if (b >= B) return;
+
+ const uint32_t level = blockIdx.y;
+ const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; // first channel handled by this thread
+
+ // locate
+ grad_grid += offsets[level] * C;
+ inputs += b * D;
+ grad += level * B * C + b * C + ch; // L, B, C
+
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+ const float scale = exp2f(level * S) * H - 1.0f;
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
+
+ // check input range (should be in [0, 1])
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if (inputs[d] < 0 || inputs[d] > 1) {
+ return; // grad is init as 0, so we simply return.
+ }
+ }
+
+ // calculate coordinate
+ float pos[D];
+ uint32_t pos_grid[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
+ pos_grid[d] = floorf(pos[d]);
+ pos[d] -= (float)pos_grid[d];
+ // smoothstep instead of linear
+ if (interp == 1) {
+ pos[d] = smoothstep(pos[d]);
+ }
+ }
+
+ scalar_t grad_cur[N_C] = {0}; // fetch to register
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c++) {
+ grad_cur[c] = grad[c];
+ }
+
+ // interpolate
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << D); idx++) { // same corner enumeration and weights as the forward kernel
+ float w = 1;
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if ((idx & (1 << d)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = pos_grid[d] + 1;
+ }
+ }
+
+ uint32_t index = get_grid_index(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
+
+ // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
+ // TODO: use float which is better than __half, if N_C % 2 != 0
+ if (std::is_same::value && N_C % 2 == 0) {
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c += 2) {
+ // process two __half at once (by interpreting as a __half2)
+ __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
+ atomicAdd((__half2*)&grad_grid[index + c], v);
+ }
+ // float, or __half when N_C % 2 != 0 (which means C == 1). NOTE(review): for at::Half this picks the at::Half atomicAdd overload defined near the top of this file, which performs no addition — confirm half with C == 1 is never used.
+ } else {
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c++) {
+ atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
+ }
+ }
+ }
+}
+
+
+template
+__global__ void kernel_input_backward( // chain rule: grad_inputs[b, d] = sum over levels l and channels ch of grad * dy_dx
+ const scalar_t * __restrict__ grad, // indexed as [l, b, ch]
+ const scalar_t * __restrict__ dy_dx, // indexed as [b, l, d, ch]
+ scalar_t * __restrict__ grad_inputs, // written at [b, d]
+ uint32_t B, uint32_t L
+) {
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
+ if (t >= B * D) return;
+ // one thread per (sample b, input axis d)
+ const uint32_t b = t / D;
+ const uint32_t d = t - b * D;
+
+ dy_dx += b * L * D * C;
+
+ scalar_t result = 0; // accumulated in scalar_t, not float
+
+ # pragma unroll
+ for (int l = 0; l < L; l++) {
+ # pragma unroll
+ for (int ch = 0; ch < C; ch++) {
+ result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
+ }
+ }
+
+ grad_inputs[t] = result; // plain store: each thread owns exactly one element
+}
+
+
+template
+void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+ static constexpr uint32_t N_THREAD = 512;
+ const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; // grid x: ceil(B / N_THREAD) blocks over the samples, grid y: one row per level
+ switch (C) { // map the runtime channel count C onto one of the supported kernel_grid launches
+ case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; // any other channel count is rejected
+ }
+}
+
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], 32-bit int (the host entry point enforces at::ScalarType::Int)
+// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
+// H: base resolution
+// dy_dx: [B, L * D * C], may be nullptr (then no input derivatives are written)
+template
+void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+ switch (D) { // dispatch on the input dimensionality D; the channel count C is dispatched inside kernel_grid_wrapper
+ case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 4: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 5: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."}; // this switch is on D, so report D (the message used to name C)
+ }
+}
+
+template
+void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+ static constexpr uint32_t N_THREAD = 256;
+ const uint32_t N_C = std::min(2u, C); // n_features_per_thread
+ const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 }; // grid x covers B * C / N_C work items, grid y: one row per level
+ switch (C) { // per channel count: first the embedding gradient, then (only when dy_dx is non-null) the input gradient
+ case 1:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 2:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 4:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 8:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; // any other channel count is rejected
+ }
+}
+
+
+// grad: [L, B, C], float
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], 32-bit int (the host entry point enforces at::ScalarType::Int)
+// grad_embeddings: [sO, C]
+// H: base resolution
+template
+void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+ switch (D) { // dispatch on the input dimensionality D; the channel count C is dispatched inside kernel_grid_backward_wrapper
+ case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
+ case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
+ case 4: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
+ case 5: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
+ default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."}; // this switch is on D, so report D (the message used to name C)
+ }
+}
+
+
+
+void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { // host entry point: validate the tensors, then launch the forward kernels
+ CHECK_CUDA(inputs);
+ CHECK_CUDA(embeddings);
+ CHECK_CUDA(offsets);
+ CHECK_CUDA(outputs);
+ if (dy_dx.has_value()) CHECK_CUDA(dy_dx.value()); // optional, but its raw pointer is handed to the kernel when present
+
+ CHECK_CONTIGUOUS(inputs);
+ CHECK_CONTIGUOUS(embeddings);
+ CHECK_CONTIGUOUS(offsets);
+ CHECK_CONTIGUOUS(outputs);
+ if (dy_dx.has_value()) CHECK_CONTIGUOUS(dy_dx.value());
+
+ CHECK_IS_FLOATING(inputs);
+ CHECK_IS_FLOATING(embeddings);
+ CHECK_IS_INT(offsets);
+ CHECK_IS_FLOATING(outputs);
+ if (dy_dx.has_value()) CHECK_IS_FLOATING(dy_dx.value());
+ // Dispatch on the dtype of `embeddings`; a missing dy_dx is passed down as nullptr, which disables the derivative output.
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ embeddings.scalar_type(), "grid_encode_forward", ([&] {
+ grid_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, gridtype, align_corners, interp);
+ }));
+}
+
+void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { // host entry point: validate the tensors, then launch the backward kernels
+ CHECK_CUDA(grad);
+ CHECK_CUDA(inputs);
+ CHECK_CUDA(embeddings);
+ CHECK_CUDA(offsets);
+ CHECK_CUDA(grad_embeddings);
+ if (dy_dx.has_value()) CHECK_CUDA(dy_dx.value()); // optional tensors are validated only when present
+ if (grad_inputs.has_value()) CHECK_CUDA(grad_inputs.value());
+
+ CHECK_CONTIGUOUS(grad);
+ CHECK_CONTIGUOUS(inputs);
+ CHECK_CONTIGUOUS(embeddings);
+ CHECK_CONTIGUOUS(offsets);
+ CHECK_CONTIGUOUS(grad_embeddings);
+ if (dy_dx.has_value()) CHECK_CONTIGUOUS(dy_dx.value());
+ if (grad_inputs.has_value()) CHECK_CONTIGUOUS(grad_inputs.value());
+
+ CHECK_IS_FLOATING(grad);
+ CHECK_IS_FLOATING(inputs);
+ CHECK_IS_FLOATING(embeddings);
+ CHECK_IS_INT(offsets);
+ CHECK_IS_FLOATING(grad_embeddings);
+ if (dy_dx.has_value()) CHECK_IS_FLOATING(dy_dx.value());
+ if (grad_inputs.has_value()) CHECK_IS_FLOATING(grad_inputs.value());
+ TORCH_CHECK(!dy_dx.has_value() || grad_inputs.has_value(), "grad_inputs must be provided when dy_dx is provided"); // the input-gradient kernel runs whenever dy_dx is non-null and writes through grad_inputs
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad.scalar_type(), "grid_encode_backward", ([&] {
+ grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr() : nullptr, gridtype, align_corners, interp);
+ }));
+
+}
+
+
+template
+__global__ void kernel_grad_tv( // adds a total-variation style gradient to `grad` for the grid cell each sample falls into
+ const scalar_t * __restrict__ inputs,
+ const scalar_t * __restrict__ grid,
+ scalar_t * __restrict__ grad, // accumulated with atomicAdd, same per-level offsetting as grid
+ const int * __restrict__ offsets,
+ const float weight,
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
+ const uint32_t gridtype,
+ const bool align_corners
+) {
+ const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (b >= B) return;
+
+ const uint32_t level = blockIdx.y;
+
+ // locate
+ inputs += b * D;
+ grid += (uint32_t)offsets[level] * C;
+ grad += (uint32_t)offsets[level] * C;
+
+ // check input range (should be in [0, 1])
+ bool flag_oob = false;
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if (inputs[d] < 0 || inputs[d] > 1) {
+ flag_oob = true;
+ }
+ }
+
+ // if input out of bound, do nothing
+ if (flag_oob) return;
+
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+ const float scale = exp2f(level * S) * H - 1.0f;
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
+
+ // calculate coordinate
+ float pos[D];
+ uint32_t pos_grid[D]; // [0, resolution]
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
+ pos_grid[d] = floorf(pos[d]);
+ // pos[d] -= (float)pos_grid[d]; // not used
+ }
+
+ //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
+
+ // total variation on pos_grid
+ scalar_t results[C] = {0}; // temp results in register; per channel: sum of (centre - neighbour) over all visited neighbours
+ scalar_t idelta[C] = {0}; // per channel: sum of (centre - neighbour)^2 over the same neighbours
+
+ uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);
+
+ scalar_t w = weight / (2 * D); // weight divided by the maximum neighbour count (two per axis)
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+
+ uint32_t cur_d = pos_grid[d];
+ scalar_t grad_val;
+
+ // right side
+ if (cur_d < resolution) {
+ pos_grid[d] = cur_d + 1;
+ uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ // results[ch] += w * clamp(grid[index + ch] - grid[index_right + ch], -1.0f, 1.0f);
+ grad_val = (grid[index + ch] - grid[index_right + ch]);
+ results[ch] += grad_val;
+ idelta[ch] += grad_val * grad_val;
+ }
+ }
+
+ // left side
+ if (cur_d > 0) {
+ pos_grid[d] = cur_d - 1;
+ uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ // results[ch] += w * clamp(grid[index + ch] - grid[index_left + ch], -1.0f, 1.0f);
+ grad_val = (grid[index + ch] - grid[index_left + ch]);
+ results[ch] += grad_val;
+ idelta[ch] += grad_val * grad_val;
+ }
+ }
+
+ // reset
+ pos_grid[d] = cur_d;
+ }
+
+ // writing to global memory (slow)
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ // index may collide, so use atomic! The summed differences are scaled by 1 / sqrt(sum of squares + 1e-9); the epsilon avoids dividing by zero.
+ atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));
+ }
+
+}
+
+
+template
+void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
+ static constexpr uint32_t N_THREAD = 512;
+ const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; // grid x: ceil(B / N_THREAD) blocks over the samples, grid y: one row per level
+ switch (C) { // map the runtime channel count C onto one of the supported kernel_grad_tv launches
+ case 1: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+ case 2: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+ case 4: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+ case 8: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; // any other channel count is rejected
+ }
+}
+
+
+template
+void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
+ switch (D) { // dispatch on the input dimensionality D; the channel count C is dispatched inside kernel_grad_tv_wrapper
+ case 2: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
+ case 3: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
+ case 4: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
+ case 5: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
+ default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."}; // this switch is on D, so report D (the message used to name C)
+ }
+}
+
+
+void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
+ // Host entry point; dispatches on the dtype of `embeddings`. NOTE(review): unlike grid_encode_forward/backward, no CHECK_CUDA / CHECK_CONTIGUOUS / dtype checks run here before the raw pointers are taken.
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ embeddings.scalar_type(), "grad_total_variation", ([&] {
+ grad_total_variation_cuda(inputs.data_ptr(), embeddings.data_ptr(), grad.data_ptr(), offsets.data_ptr(), weight, B, D, C, L, S, H, gridtype, align_corners);
+ }));
+}
\ No newline at end of file
diff --git a/core/lib/gridencoder/src/gridencoder.h b/core/lib/gridencoder/src/gridencoder.h
new file mode 100755
index 0000000..1b38575
--- /dev/null
+++ b/core/lib/gridencoder/src/gridencoder.h
@@ -0,0 +1,17 @@
+#ifndef _HASH_ENCODE_H
+#define _HASH_ENCODE_H
+
+#include
+#include
+
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], 32-bit int (checked with CHECK_IS_INT in gridencoder.cu)
+// outputs: [L, B, C], float (level-major, as indexed by the forward kernel in gridencoder.cu)
+// H: base resolution
+void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
+void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
+// grad_total_variation accumulates into `grad`, which is offset per level exactly like `embeddings`.
+void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners);
+
+#endif
\ No newline at end of file
diff --git a/core/lib/guidance.py b/core/lib/guidance.py
new file mode 100755
index 0000000..62fec7a
--- /dev/null
+++ b/core/lib/guidance.py
@@ -0,0 +1,390 @@
+from transformers import CLIPTextModel, CLIPTokenizer, logging
+from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, ControlNetModel
+
+# suppress partial model loading warning
+logging.set_verbosity_error()
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as T
+
+import numpy as np
+import PIL
+
+from torch.cuda.amp import custom_bwd, custom_fwd
+import clip
+
class SpecifyGradient(torch.autograd.Function):
    """Autograd function that injects a precomputed gradient for its input.

    ``forward`` returns a one-element zero tensor that only serves as a handle
    for ``.backward()``. ``backward`` hands back the stored gradient divided by
    the length of its leading dimension.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        # Only the gradient is remembered; the returned value is a placeholder
        # living on the same device / dtype as the input.
        ctx.save_for_backward(gt_grad)
        placeholder = torch.zeros([1], device=input_tensor.device, dtype=input_tensor.dtype)
        return placeholder  # dummy loss value

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        # The incoming gradient `grad` is not used: scaling the dummy loss has
        # no effect on what is propagated to the input.
        (stored_grad,) = ctx.saved_tensors
        num_items = len(stored_grad)
        # No gradient is produced for the `gt_grad` argument itself.
        return stored_grad / num_items, None
+
def seed_everything(seed):
    """Seed PyTorch's CPU and CUDA random number generators with ``seed``.

    The cuDNN ``deterministic`` / ``benchmark`` flags are not modified here.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
+
class StableDiffusion(nn.Module):
    """Stable Diffusion guidance model with optional ControlNet, LoRA and CLIP losses.

    Loads the VAE, tokenizer, text encoder, UNet and DDIM scheduler of a Stable
    Diffusion checkpoint (plus an optional second "head" checkpoint) and exposes
    ``train_step``, which turns a rendered image into a score-distillation
    loss, as well as plain text-to-image sampling helpers.
    """

    def __init__(self, device, sd_version='2.1', hf_key=None, sd_step_range=(0.2, 0.98), controlnet=None, lora=None, cfg=None, head_hf_key=None):
        """
        Args:
            device: torch device every sub-model is moved to.
            sd_version: '2.1', '2.0' or '1.5'; ignored when ``hf_key`` is given.
            hf_key: optional Hugging Face model id / path overriding ``sd_version``.
            sd_step_range: (low, high) fractions of the scheduler's training
                timesteps between which ``train_step`` samples its timestep.
            controlnet: optional ControlNet model id / path.
            lora: optional attention-processor weights loaded into the UNet.
            cfg: config object; ``train_step`` reads ``clip_step_range``,
                ``clip_guidance_scale``, ``lambda_clip_img_loss``,
                ``lambda_clip_text_loss`` and ``text`` from it when a CLIP
                reference image is supplied.
            head_hf_key: optional checkpoint whose tokenizer / text encoder /
                UNet are used for ``is_face=True`` calls.

        Raises:
            ValueError: if ``hf_key`` is None and ``sd_version`` is unsupported.
        """
        super().__init__()
        self.cfg = cfg
        self.device = device
        self.sd_version = sd_version

        print('[INFO] loading stable diffusion...')

        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        # Frozen CLIP image/text encoder used by the CLIP losses.
        self.clip_model, _ = clip.load("ViT-L/14", device=self.device, jit=False, download_root='clip_ckpts')
        self.clip_model = self.clip_model.eval().requires_grad_(False).to(self.device)
        self.clip_preprocess = T.Compose([
            T.Resize((224, 224)),
            T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

        # Create model
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)

        # Optional dedicated models for face/head crops; otherwise alias the main ones.
        self.use_head_model = head_hf_key is not None
        if self.use_head_model:
            self.tokenizer_head = CLIPTokenizer.from_pretrained(head_hf_key, subfolder="tokenizer")
            self.text_encoder_head = CLIPTextModel.from_pretrained(head_hf_key, subfolder="text_encoder").to(self.device)
            self.unet_head = UNet2DConditionModel.from_pretrained(head_hf_key, subfolder="unet").to(self.device)
        else:
            self.tokenizer_head = self.tokenizer
            self.text_encoder_head = self.text_encoder
            self.unet_head = self.unet

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * sd_step_range[0])
        self.max_step = int(self.num_train_timesteps * sd_step_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        if controlnet is None:
            self.controlnet = None
        else:
            self.controlnet = ControlNetModel.from_pretrained(controlnet).to(self.device)

        if lora is not None:
            self.unet.load_attn_procs(lora)

        print('[INFO] loaded stable diffusion!')

    def img_clip_loss(self, rgb1, rgb2):
        """Negative mean cosine similarity between the CLIP embeddings of two image batches."""
        image_z_1 = self.clip_model.encode_image(self.clip_preprocess(rgb1))
        image_z_2 = self.clip_model.encode_image(self.clip_preprocess(rgb2))
        image_z_1 = image_z_1 / image_z_1.norm(dim=-1, keepdim=True)  # normalize features
        image_z_2 = image_z_2 / image_z_2.norm(dim=-1, keepdim=True)  # normalize features

        loss = - (image_z_1 * image_z_2).sum(-1).mean()
        return loss

    def img_text_clip_loss(self, rgb, prompts):
        """Negative mean cosine similarity between CLIP image and text embeddings.

        Args:
            rgb: image batch, preprocessed with ``self.clip_preprocess``.
            prompts: text prompt(s) handed to ``clip.tokenize``.
        """
        # Same preprocessing as img_clip_loss (the class defines no other augmentation).
        image_z_1 = self.clip_model.encode_image(self.clip_preprocess(rgb))
        image_z_1 = image_z_1 / image_z_1.norm(dim=-1, keepdim=True)  # normalize features

        text = clip.tokenize(prompts).to(self.device)
        text_z = self.clip_model.encode_text(text)
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
        loss = - (image_z_1 * text_z).sum(-1).mean()
        return loss

    def get_text_embeds(self, prompt, negative_prompt, is_face=False):
        """Encode positive and negative prompts.

        Args:
            prompt, negative_prompt: [str]
            is_face: use the head tokenizer / text encoder instead of the main ones.

        Returns:
            Text-encoder hidden states with the negative (unconditional) batch
            first and the positive batch second, concatenated along dim 0.
        """
        print('text prompt: [positive]', prompt, '[negative]', negative_prompt)
        if not is_face:
            tokenizer = self.tokenizer
            text_encoder = self.text_encoder
        else:
            tokenizer = self.tokenizer_head
            text_encoder = self.text_encoder_head

        # Tokenize text and get embeddings
        text_input = tokenizer(prompt, padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt')

        with torch.no_grad():
            text_embeddings = text_encoder(text_input.input_ids.to(self.device))[0]

        # Do the same for unconditional embeddings. Truncate as well so that an
        # over-long negative prompt keeps the same sequence length as the positive one.
        uncond_input = tokenizer(negative_prompt, padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt')

        with torch.no_grad():
            uncond_embeddings = text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Cat for final embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, controlnet_hint=None, controlnet_conditioning_scale=1.0, clip_ref_img=None, is_face=False, **kwargs):
        """Compute the guidance loss for a rendered image.

        Args:
            text_embeddings: output of ``get_text_embeds`` (negative batch first).
            pred_rgb: rendered image batch; resized to 512x512 before VAE encoding.
            guidance_scale: classifier-free guidance scale.
            controlnet_hint: optional ControlNet conditioning (tensor, ndarray or PIL image).
            controlnet_conditioning_scale: scale forwarded to the ControlNet.
            clip_ref_img: optional reference image enabling the CLIP branch for
                timesteps below ``cfg.clip_step_range * num_train_timesteps``.
            is_face: use the head UNet.

        Returns:
            The ``SpecifyGradient`` dummy loss whose backward injects the SDS
            gradient into the latents, or, in the CLIP branch, the weighted sum
            of the enabled CLIP losses (``0`` if none is enabled).
        """
        unet = self.unet_head if is_face else self.unet

        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)

        # Test against None explicitly: the truth value of a multi-element
        # tensor / ndarray hint is ambiguous and raises.
        if controlnet_hint is not None:
            assert self.controlnet is not None
            controlnet_hint = self.controlnet_hint_conversion(controlnet_hint, 512, 512)

        # timestep sampled uniformly from [min_step, max_step] to avoid very high/low noise levels
        t = torch.randint(self.min_step, self.max_step + 1, [1], dtype=torch.long, device=self.device)

        # encode image into latents with vae, requires grad!
        latents = self.encode_imgs(pred_rgb_512)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise (unconditional + conditional in one batch)
            latent_model_input = torch.cat([latents_noisy] * 2)
            if controlnet_hint is not None:
                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings,
                    controlnet_cond=controlnet_hint,
                    conditioning_scale=controlnet_conditioning_scale,
                    return_dict=False
                )
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
                                  down_block_additional_residuals=down_block_res_samples,
                                  mid_block_additional_residual=mid_block_res_sample,).sample
            else:
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance (high scale from paper!)
        if self.scheduler.config.prediction_type == "v_prediction":
            # Convert the v-prediction into a noise prediction.
            alphas_cumprod = self.scheduler.alphas_cumprod.to(
                device=latents_noisy.device, dtype=latents_noisy.dtype
            )
            alpha_t = alphas_cumprod[t] ** 0.5
            sigma_t = (1 - alphas_cumprod[t]) ** 0.5

            noise_pred = latent_model_input * torch.cat([sigma_t] * 2, dim=0).view(
                -1, 1, 1, 1
            ) + noise_pred * torch.cat([alpha_t] * 2, dim=0).view(-1, 1, 1, 1)
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

        # NOTE(review): the nesting of the two branches below was reconstructed;
        # the SDS loss is kept in the `else` branch so that the CLIP loss
        # computed in the first branch is what gets returned instead of being
        # overwritten. Confirm against the intended training behaviour.
        if clip_ref_img is not None and t < self.cfg.clip_step_range * self.num_train_timesteps:
            guidance_scale = self.cfg.clip_guidance_scale
            noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)
            self.scheduler.set_timesteps(self.num_train_timesteps)
            de_latents = self.scheduler.step(noise_pred, t, latents_noisy)['prev_sample']
            # decode_latents runs the VAE decoder under no_grad.
            imgs = self.decode_latents(de_latents)
            loss = 0
            if self.cfg.lambda_clip_img_loss > 0:
                loss = loss + self.img_clip_loss(imgs, clip_ref_img) * self.cfg.lambda_clip_img_loss
            if self.cfg.lambda_clip_text_loss > 0:
                text = self.cfg.text.replace('sks', '')
                loss = loss + self.img_text_clip_loss(imgs, [text]) * self.cfg.lambda_clip_text_loss
        else:
            noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # w(t), sigma_t^2
            w = (1 - self.alphas[t])
            grad = w * (noise_pred - noise)

            # replace NaN / inf entries so a bad step cannot poison the parameters
            grad = torch.nan_to_num(grad)

            # since we omitted an item in grad, we need to use the custom function to specify the gradient
            loss = SpecifyGradient.apply(latents, grad)

        return loss

    def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
        """Run the scheduler's denoising loop with classifier-free guidance.

        Args:
            text_embeddings: output of ``get_text_embeds`` (negative batch first).
            height, width: image size; latents are created at 1/8 of it.
            num_inference_steps: number of scheduler steps.
            guidance_scale: classifier-free guidance scale.
            latents: optional starting latents; sampled from N(0, 1) if None.

        Returns:
            The denoised latents.
        """
        if latents is None:
            # Read the channel count from the model config (direct attribute
            # access on the model is deprecated in diffusers).
            latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.config.in_channels, height // 8, width // 8), device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for i, t in enumerate(self.scheduler.timesteps):
                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                latent_model_input = torch.cat([latents] * 2)

                # predict the noise residual
                with torch.no_grad():
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

                # perform guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']

        return latents

    def decode_latents(self, latents):
        """Decode latents with the VAE (no gradient) into images clamped to [0, 1]."""
        latents = 1 / 0.18215 * latents

        with torch.no_grad():
            imgs = self.vae.decode(latents).sample

        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs):
        """Encode images into scaled VAE latents (gradient is kept).

        Args:
            imgs: [B, 3, H, W]; rescaled here with ``2 * imgs - 1``.
        """
        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * 0.18215

        return latents

    def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
        """Text-to-image sampling; returns a uint8 numpy array of shape [B, H, W, 3]."""
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # Text embeds -> img latents
        latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)

        # Img latents -> imgs
        imgs = self.decode_latents(latents)

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype('uint8')

        return imgs

    def controlnet_hint_conversion(self, controlnet_hint, height, width, num_images_per_prompt=1):
        """Normalise a ControlNet hint to a [N, 3, height, width] tensor on the ControlNet's device / dtype.

        Accepted inputs:
            * torch.Tensor of shape chw, 1chw or Nchw (used as is, repeated to N).
            * np.ndarray of shape hw, hwc, 1hwc or Nhwc (divided by 255).
            * PIL image of size (width, height); converted to RGB, reversed to
              BGR channel order and handled like an ndarray.

        Raises:
            ValueError: on an unsupported type, shape or image size.
        """
        channels = 3
        if isinstance(controlnet_hint, torch.Tensor):
            # torch.Tensor: acceptable shapes are any of chw, bchw(b==1) or bchw(b==num_images_per_prompt)
            shape_chw = (channels, height, width)
            shape_bchw = (1, channels, height, width)
            shape_nchw = (num_images_per_prompt, channels, height, width)
            if controlnet_hint.shape in [shape_chw, shape_bchw, shape_nchw]:
                controlnet_hint = controlnet_hint.to(dtype=self.controlnet.dtype, device=self.controlnet.device)
                if controlnet_hint.shape != shape_nchw:
                    controlnet_hint = controlnet_hint.repeat(num_images_per_prompt, 1, 1, 1)
                return controlnet_hint
            else:
                raise ValueError(
                    f"Acceptable shape of `controlnet_hint` are any of ({channels}, {height}, {width}),"
                    + f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
                    + f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
                )
        elif isinstance(controlnet_hint, np.ndarray):
            # np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_prompt)
            # hwc is opencv compatible image format. Color channel must be BGR Format.
            if controlnet_hint.shape == (height, width):
                controlnet_hint = np.repeat(controlnet_hint[:, :, np.newaxis], channels, axis=2)  # hw -> hwc(c==3)
            shape_hwc = (height, width, channels)
            shape_bhwc = (1, height, width, channels)
            shape_nhwc = (num_images_per_prompt, height, width, channels)
            if controlnet_hint.shape in [shape_hwc, shape_bhwc, shape_nhwc]:
                controlnet_hint = torch.from_numpy(controlnet_hint.copy())
                controlnet_hint = controlnet_hint.to(dtype=self.controlnet.dtype, device=self.controlnet.device)
                controlnet_hint /= 255.0
                if controlnet_hint.shape != shape_nhwc:
                    controlnet_hint = controlnet_hint.repeat(num_images_per_prompt, 1, 1, 1)
                controlnet_hint = controlnet_hint.permute(0, 3, 1, 2)  # b h w c -> b c h w
                return controlnet_hint
            else:
                raise ValueError(
                    f"Acceptable shape of `controlnet_hint` are any of ({height}, {width}), "
                    + f"({height}, {width}, {channels}), "
                    + f"(1, {height}, {width}, {channels}) or "
                    + f"({num_images_per_prompt}, {height}, {width}, {channels}) but is {controlnet_hint.shape}"
                )
        elif isinstance(controlnet_hint, PIL.Image.Image):
            if controlnet_hint.size == (width, height):
                controlnet_hint = controlnet_hint.convert("RGB")  # make sure 3 channel RGB format
                controlnet_hint = np.array(controlnet_hint)  # to numpy
                controlnet_hint = controlnet_hint[:, :, ::-1]  # RGB -> BGR
                return self.controlnet_hint_conversion(controlnet_hint, height, width, num_images_per_prompt)
            else:
                raise ValueError(
                    f"Acceptable image size of `controlnet_hint` is ({width}, {height}) but is {controlnet_hint.size}"
                )
        else:
            raise ValueError(
                f"Acceptable type of `controlnet_hint` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
            )
+
+
if __name__ == '__main__':
    # Command-line demo: sample one image from a text prompt and display it.

    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument('prompt', type=str)
    parser.add_argument('--negative', default='', type=str)
    parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
    parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
    parser.add_argument('-H', type=int, default=512)
    parser.add_argument('-W', type=int, default=512)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--steps', type=int, default=50)
    opt = parser.parse_args()

    seed_everything(opt.seed)

    device = torch.device('cuda')

    sd = StableDiffusion(device, opt.sd_version, opt.hf_key)

    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)

    # visualize image: draw first, then open the window (show() before
    # imshow() would display nothing).
    plt.imshow(imgs[0])
    plt.show()
\ No newline at end of file
diff --git a/core/lib/hed_annotator.py b/core/lib/hed_annotator.py
new file mode 100755
index 0000000..f267f82
--- /dev/null
+++ b/core/lib/hed_annotator.py
@@ -0,0 +1,128 @@
+import numpy as np
+import cv2
+import os
+import torch
+from einops import rearrange
+
+
class Network(torch.nn.Module):
    """HED edge-detection network.

    Five VGG-style convolution stages each feed a 1x1 "score" convolution; the
    five score maps are resized to the input resolution and fused by a final
    1x1 convolution followed by a sigmoid.
    """

    def __init__(self, model_path):
        """Build the layers and load weights from the checkpoint at ``model_path``.

        Checkpoint keys have 'module' replaced by 'net' to match the attribute
        names used here.
        """
        super().__init__()

        # Attribute names and layer positions inside each Sequential define the
        # state-dict keys and must stay as they are.
        self.netVggOne = self._vgg_stage(3, 64, num_convs=2, pool=False)
        self.netVggTwo = self._vgg_stage(64, 128, num_convs=2)
        self.netVggThr = self._vgg_stage(128, 256, num_convs=3)
        self.netVggFou = self._vgg_stage(256, 512, num_convs=3)
        self.netVggFiv = self._vgg_stage(512, 512, num_convs=3)

        self.netScoreOne = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreTwo = torch.nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreThr = torch.nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreFou = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreFiv = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)

        self.netCombine = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
            torch.nn.Sigmoid()
        )

        # Deserialize onto the CPU so a checkpoint saved from a GPU also loads on
        # hosts without one; load_state_dict copies into this module's parameters.
        state = torch.load(model_path, map_location='cpu')
        self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in state.items()})

    @staticmethod
    def _vgg_stage(in_channels, out_channels, num_convs, pool=True):
        """Return [MaxPool2d] + num_convs x (3x3 Conv2d, ReLU) as a Sequential."""
        layers = [torch.nn.MaxPool2d(kernel_size=2, stride=2)] if pool else []
        for i in range(num_convs):
            layers.append(torch.nn.Conv2d(in_channels=in_channels if i == 0 else out_channels,
                                          out_channels=out_channels, kernel_size=3, stride=1, padding=1))
            layers.append(torch.nn.ReLU(inplace=False))
        return torch.nn.Sequential(*layers)

    def forward(self, tenInput):
        """Return the fused edge map, shape [B, 1, H, W], for a [B, 3, H, W] input.

        The input is multiplied by 255 and has a fixed per-channel mean subtracted.
        """
        tenInput = tenInput * 255.0
        tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)

        tenVggOne = self.netVggOne(tenInput)
        tenVggTwo = self.netVggTwo(tenVggOne)
        tenVggThr = self.netVggThr(tenVggTwo)
        tenVggFou = self.netVggFou(tenVggThr)
        tenVggFiv = self.netVggFiv(tenVggFou)

        scores = [
            self.netScoreOne(tenVggOne),
            self.netScoreTwo(tenVggTwo),
            self.netScoreThr(tenVggThr),
            self.netScoreFou(tenVggFou),
            self.netScoreFiv(tenVggFiv),
        ]

        # Bring every side output back to the input resolution before fusing.
        out_size = (tenInput.shape[2], tenInput.shape[3])
        scores = [
            torch.nn.functional.interpolate(input=score, size=out_size, mode='bilinear', align_corners=False)
            for score in scores
        ]

        return self.netCombine(torch.cat(scores, 1))
+
+
class HEDdetector:
    """Callable wrapper around the HED ``Network`` that fetches its checkpoint when missing."""

    def __init__(self):
        # The checkpoint lives in ./ckpts and is downloaded only if absent.
        ckpt_dir = 'ckpts'
        ckpt_url = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth"
        ckpt_file = os.path.join(ckpt_dir, "network-bsds500.pth")
        if not os.path.exists(ckpt_file):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(ckpt_url, model_dir=ckpt_dir)
        network = Network(ckpt_file)
        self.netNetwork = network.cuda().eval()

    def __call__(self, image_hed):
        """Return the edge map of one 'h w c' image as a uint8 array of shape (h, w)."""
        with torch.no_grad():
            batched = rearrange(image_hed, 'h w c -> 1 c h w')
            # First (only) batch item of the network output.
            edge_map = self.netNetwork(batched)[0]
            edge_u8 = (edge_map.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
            # Drop the remaining leading axis.
            return edge_u8[0]
+
+
def nms(x, t, s):
    """Thin an edge map by directional non-maximum suppression.

    Args:
        x: 2-D edge map (converted to float32 and Gaussian-blurred with sigma ``s``).
        t: threshold applied to the surviving responses.
        s: Gaussian sigma.

    Returns:
        uint8 array with 255 where a pixel equals the maximum of its 3-pixel
        line neighbourhood in at least one of four directions and exceeds ``t``,
        0 elsewhere.
    """
    blurred = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    # 3x3 line-shaped structuring elements: horizontal, vertical, two diagonals.
    line_patterns = (
        [[0, 0, 0], [1, 1, 1], [0, 0, 0]],
        [[0, 1, 0], [0, 1, 0], [0, 1, 0]],
        [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
        [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
    )

    survivors = np.zeros_like(blurred)
    for pattern in line_patterns:
        kernel = np.array(pattern, dtype=np.uint8)
        # A pixel is a directional maximum when dilation leaves it unchanged.
        is_local_max = cv2.dilate(blurred, kernel=kernel) == blurred
        np.putmask(survivors, is_local_max, blurred)

    binary = np.zeros_like(survivors, dtype=np.uint8)
    binary[survivors > t] = 255
    return binary
diff --git a/core/lib/loss_utils.py b/core/lib/loss_utils.py
new file mode 100755
index 0000000..2fc9199
--- /dev/null
+++ b/core/lib/loss_utils.py
@@ -0,0 +1,172 @@
+
+import torch
+import torch.nn.functional as F
+from scipy.ndimage import distance_transform_edt
+
def crop_by_mask(rgb, alpha, base_size=64):
    """Crop ``rgb`` and ``alpha`` to the foreground box of the first alpha map.

    The box is taken from ``alpha[0, 0] > 0`` and then grown, first at the far
    edge and then at the near edge, so that its height and width become
    multiples of ``base_size`` where the image border allows it.

    Args:
        rgb: [B, C, H, W] tensor.
        alpha: [B, 1, H, W] tensor; only ``alpha[0, 0]`` defines the box.
        base_size: desired granularity of the crop size.

    Returns:
        ``(rgb_crop, alpha_crop)`` sliced as ``[..., y_min:y_max, x_min:x_max]``
        (end-exclusive). If the mask has no foreground pixel, both crops are
        empty along the spatial axes.
    """
    foreground = alpha[0, 0] > 0
    h, w = foreground.shape

    # Indices of rows / columns containing at least one foreground pixel
    # (ascending), instead of materialising two HxW coordinate grids.
    rows = torch.nonzero(foreground.any(dim=1)).flatten()
    cols = torch.nonzero(foreground.any(dim=0)).flatten()
    if rows.numel() == 0:
        return rgb[:, :, 0:0, 0:0], alpha[:, :, 0:0, 0:0]

    def snap(lo, hi, limit):
        # Grow the far edge first, clamped to the image ...
        remainder = (hi - lo) % base_size
        if remainder > 0:
            hi = min(hi + base_size - remainder, limit - 1)
        # ... and move the near edge only if that was not enough.
        remainder = (hi - lo) % base_size
        if remainder > 0:
            lo = max(lo - base_size + remainder, 0)
        return lo, hi

    x_min, x_max = snap(int(cols[0]), int(cols[-1]), w)
    y_min, y_max = snap(int(rows[0]), int(rows[-1]), h)
    return rgb[:, :, y_min:y_max, x_min:x_max], alpha[:, :, y_min:y_max, x_min:x_max]
+
def silhouette_loss(alpha, gt_mask, edt=None, loss_mask=None, kernel_size=7, edt_power=0.25, l2_weight=0.01, edge_weight=0.01):
    """
    Silhouette loss: weighted sum of an L2 term and an edge-distance term.

    Inputs:
        alpha: Bx1xHxW Tensor, predicted alpha,
        gt_mask: Bx1xHxW Tensor, ground-truth mask
        edt[Optional]: precomputed edge-distance map (see get_edt); computed
            from gt_mask when None
        loss_mask[Optional]: Bx1xHxW Tensor, loss mask, calculate loss inside the mask only
        kernel_size: edge filter kernel size
        edt_power: edge distance power in the loss (distances are raised to 2 * edt_power)
        l2_weight: loss weight of the l2 loss
        edge_weight: loss weight of the edge loss
    Output:
        loss
    """
    def compute_edge(x):
        # Max-pool dilation minus the input: positive just outside the mask.
        return F.max_pool2d(x, kernel_size, 1, kernel_size // 2) - x

    sil_l2loss = (gt_mask - alpha) ** 2
    if loss_mask is not None:
        sil_l2loss = sil_l2loss * loss_mask

    if edt is None:
        gt_edge = compute_edge(gt_mask).detach().cpu().numpy() > 0
        # Distance transform per HxW plane, so distances never leak between
        # batch items (a transform over the full 4-D array would treat the
        # batch axis as a spatial one).
        planes = gt_edge.reshape((-1,) + gt_edge.shape[-2:])
        distances = [torch.from_numpy(distance_transform_edt(1 - plane) ** (edt_power * 2)) for plane in planes]
        edt = torch.stack(distances).reshape(gt_edge.shape).to(dtype=torch.float32, device=gt_mask.device)

    # The predicted edge has to exist before it can be masked.
    pred_edge = compute_edge(alpha)
    if loss_mask is not None:
        pred_edge = pred_edge * loss_mask

    sil_edgeloss = torch.sum(pred_edge * edt.to(pred_edge.device)) / (pred_edge.sum() + 1e-7)
    return sil_l2loss.mean() * l2_weight + sil_edgeloss * edge_weight
+
def get_edt(gt_mask, loss_mask=None, kernel_size=7, edt_power=0.25, l2_weight=0.01, edge_weight=0.01):
    """
    Precompute the edge-distance map used by the silhouette edge loss.

    Inputs:
        gt_mask: Bx1xHxW Tensor, ground-truth mask
        kernel_size: edge filter kernel size
        edt_power: distances are raised to the power 2 * edt_power
        loss_mask, l2_weight, edge_weight: accepted for signature parity, unused
    Output:
        float32 Tensor shaped like gt_mask on gt_mask's device; every pixel holds
        its (powered) Euclidean distance to the nearest edge pixel of its own
        HxW plane
    """
    def compute_edge(x):
        return F.max_pool2d(x, kernel_size, 1, kernel_size // 2) - x

    gt_edge = compute_edge(gt_mask).detach().cpu().numpy() > 0
    # One distance transform per HxW plane: running it over the whole 4-D array
    # would measure distances across the batch axis as well.
    planes = gt_edge.reshape((-1,) + gt_edge.shape[-2:])
    distances = [torch.from_numpy(distance_transform_edt(1 - plane) ** (edt_power * 2)) for plane in planes]
    edt = torch.stack(distances).reshape(gt_edge.shape).to(dtype=torch.float32, device=gt_mask.device)
    return edt
+
+
def laplacian_cot(verts, faces):
    """
    Compute the cotangent laplacian
    Inspired by https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/loss/mesh_laplacian_smoothing.html
    Parameters
    ----------
    verts : torch.Tensor
        Vertex positions, shape (V, 3).
    faces : torch.Tensor
        array of triangle faces, shape (F, 3), integer vertex indices.
    Returns
    -------
    torch.Tensor
        Sparse (V, V) matrix D - W on the device of `verts`, where W holds the
        symmetrised cotangent weights and D is the diagonal of W's column sums.
    """
    num_verts, num_faces = verts.shape[0], faces.shape[0]

    face_verts = verts[faces]
    v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]

    # Side lengths of each triangle, of shape (F,)
    # A is the side opposite v0, B is opposite v1, and C is opposite v2
    A = (v1 - v2).norm(dim=1)
    B = (v0 - v2).norm(dim=1)
    C = (v0 - v1).norm(dim=1)

    # Area of each triangle (with Heron's formula); shape is (F,)
    s = 0.5 * (A + B + C)
    # note that the area can be negative (close to 0) causing nans after sqrt()
    # we clip it to a small positive value
    area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()

    # Compute cotangents of angles, of shape (F, 3)
    A2, B2, C2 = A * A, B * B, C * C
    cota = (B2 + C2 - A2) / area
    cotb = (A2 + C2 - B2) / area
    cotc = (A2 + B2 - C2) / area
    cot = torch.stack([cota, cotb, cotc], dim=1)
    cot /= 4.0

    # Construct a sparse matrix by basically doing:
    # W[v1, v2] = cota
    # W[v2, v0] = cotb
    # W[v0, v1] = cotc
    ii = faces[:, [1, 2, 0]]
    jj = faces[:, [2, 0, 1]]
    idx = torch.stack([ii, jj], dim=0).reshape(2, num_faces * 3)
    weights = torch.sparse_coo_tensor(idx, cot.reshape(-1), (num_verts, num_verts))

    # Make it symmetric; this means we are also setting
    # W[v2, v1] = cota
    # W[v0, v2] = cotb
    # W[v1, v0] = cotc
    weights = (weights + weights.t()).coalesce()

    # Add the diagonal entries; build the indices on the mesh's own device so
    # the function also works on CPU or a non-default GPU.
    vals = torch.sparse.sum(weights, dim=0).to_dense()
    indices = torch.arange(num_verts, device=verts.device)
    diag_idx = torch.stack([indices, indices], dim=0)
    return torch.sparse_coo_tensor(diag_idx, vals, (num_verts, num_verts)) - weights
+
+
def laplacian_uniform(verts, faces):
    """
    Compute the uniform laplacian
    Parameters
    ----------
    verts : torch.Tensor
        Vertex positions; only the vertex count is used.
    faces : torch.Tensor
        array of triangle faces, shape (F, 3), integer vertex indices.
    Returns
    -------
    torch.Tensor
        Coalesced sparse (V, V) matrix with -1 for every pair of adjacent
        vertices and the vertex degree on the diagonal.
    """
    num_verts = verts.shape[0]

    # Neighbor indices: every directed edge of every triangle, in both
    # directions, de-duplicated.
    ii = faces[:, [1, 2, 0]].flatten()
    jj = faces[:, [2, 0, 1]].flatten()
    adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(dim=1)
    # Values live on the same device as the indices they are paired with.
    adj_values = torch.ones(adj.shape[1], device=adj.device, dtype=torch.float)

    # Diagonal indices: one (i, i) entry per neighbour of i
    diag_idx = adj[0]

    # Build the sparse matrix
    idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
    values = torch.cat((-adj_values, adj_values))

    # The coalesce operation sums the duplicate indices, resulting in the
    # correct diagonal
    return torch.sparse_coo_tensor(idx, values, (num_verts, num_verts)).coalesce()
+
def laplacian_smooth_loss(v_pos, t_pos_idx):
    """Mean squared, degree-normalised neighbour offset of every vertex.

    Args:
        v_pos: (V, 3) vertex positions.
        t_pos_idx: (F, 3) integer triangle vertex indices.

    For each triangle corner, the offsets to the two other corners are summed
    into that vertex; each contribution counts with weight 2. The accumulated
    offsets are divided by the accumulated weight (at least 1) and the mean of
    the squared result is returned.
    """
    offsets = torch.zeros_like(v_pos)
    weights = torch.zeros_like(v_pos[..., 0:1])

    corners = [v_pos[t_pos_idx[:, k], :] for k in range(3)]

    # For corner k, `others` are the two remaining corners of the triangle.
    for k, others in enumerate(((1, 2), (0, 2), (0, 1))):
        first, second = others
        pull = (corners[first] - corners[k]) + (corners[second] - corners[k])
        offsets.scatter_add_(0, t_pos_idx[:, k:k + 1].repeat(1, 3), pull)

    # Two neighbours contribute per triangle corner.
    per_corner_weight = torch.ones_like(corners[0]) * 2.0
    for k in range(3):
        weights.scatter_add_(0, t_pos_idx[:, k:k + 1], per_corner_weight)

    offsets = offsets / torch.clamp(weights, min=1.0)

    return torch.mean(offsets ** 2)
\ No newline at end of file
diff --git a/core/lib/marching_tets.py b/core/lib/marching_tets.py
new file mode 100755
index 0000000..7aecefb
--- /dev/null
+++ b/core/lib/marching_tets.py
@@ -0,0 +1,145 @@
+import torch
+from torch import Tensor, nn
+import numpy as np
+
+
+###############################################################################
+# Marching tetrahedrons implementation (differentiable), adapted from
+# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py
+###############################################################################
class DMTet(nn.Module):
    """Differentiable marching tetrahedra.

    Extracts a triangle mesh from signed-distance values given at the vertices
    of a tetrahedral grid. Mesh vertices are placed on the tet edges whose
    endpoints have opposite occupancy, at the linear zero crossing of the SDF.
    """

    def __init__(self):
        super().__init__()
        # Row index = occupancy code of a tet (bit i set when vertex i has
        # sdf > 0); entries index the six tet edges, -1 pads unused slots.
        triangle_table = torch.tensor([
            [-1, -1, -1, -1, -1, -1],
            [ 1,  0,  2, -1, -1, -1],
            [ 4,  0,  3, -1, -1, -1],
            [ 1,  4,  2,  1,  3,  4],
            [ 3,  1,  5, -1, -1, -1],
            [ 2,  3,  0,  2,  5,  3],
            [ 1,  4,  0,  1,  5,  4],
            [ 4,  2,  5, -1, -1, -1],
            [ 4,  5,  2, -1, -1, -1],
            [ 4,  1,  0,  4,  5,  1],
            [ 3,  2,  0,  3,  5,  2],
            [ 1,  3,  5, -1, -1, -1],
            [ 4,  1,  2,  4,  3,  1],
            [ 3,  0,  4, -1, -1, -1],
            [ 2,  0,  1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1],
        ], dtype=torch.long)  # yapf: disable

        # Number of triangles emitted per occupancy code.
        num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long)
        # The six edges of a tet as consecutive (vertex, vertex) pairs.
        base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long)

        self.register_buffer('triangle_table', triangle_table, persistent=False)
        self.register_buffer('num_triangles_table', num_triangles_table, persistent=False)
        self.register_buffer('base_tet_edges', base_tet_edges, persistent=False)

    ###############################################################################
    # Utility functions
    ###############################################################################

    def sort_edges(self, edges_ex2):
        """Order the two vertex indices of every edge ascending.

        Args:
            edges_ex2: (E, 2) integer tensor.

        Returns:
            (E, 1, 2) tensor holding (smaller, larger) index per edge.
        """
        with torch.no_grad():
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)

        return torch.stack([a, b], -1)

    def map_uv(self, faces, face_gidx, max_idx):
        """Assign every face a corner triple inside its own cell of an N x N texture grid.

        Args:
            faces: unused.
            face_gidx: global face index per face (two slots per tet).
            max_idx: upper bound on the global face index.

        Returns:
            ``(uvs, uv_idx)``: the four corner UVs of every grid cell, and per
            face the indices of the three corners it uses.
        """
        N = int(np.ceil(np.sqrt((max_idx + 1) // 2)))
        tex_y, tex_x = torch.meshgrid(
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=face_gidx.device),
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=face_gidx.device),
            indexing='ij')

        # Cells are shrunk to 90% so neighbouring cells do not touch.
        pad = 0.9 / N

        uvs = torch.stack([tex_x, tex_y, tex_x + pad, tex_y, tex_x + pad, tex_y + pad, tex_x, tex_y + pad],
                          dim=-1).view(-1, 2)

        def _idx(tet_idx, N):
            x = tet_idx % N
            y = torch.div(tet_idx, N, rounding_mode='trunc')
            return y * N + x

        tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
        tri_idx = face_gidx % 2

        uv_idx = torch.stack((tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2), dim=-1).view(-1, 3)

        return uvs, uv_idx

    ###############################################################################
    # Marching tets implementation
    ###############################################################################

    def __call__(self, pos_nx3: Tensor, sdf_n: Tensor, tet_fx4: Tensor, with_uv: bool=True):
        """Extract the zero level set of ``sdf_n`` as a triangle mesh.

        Args:
            pos_nx3: (N, 3) grid vertex positions.
            sdf_n: (N,) signed distance per grid vertex; > 0 counts as occupied.
            tet_fx4: (F, 4) vertex indices of every tetrahedron.
            with_uv: also return per-face texture coordinates.

        Returns:
            ``(verts, faces)`` or ``(verts, faces, uvs, uv_idx)``. ``verts`` is
            computed from ``pos_nx3`` / ``sdf_n`` outside of ``no_grad`` so
            gradients reach both.
        """
        # Topology (which edges are crossed) is discrete: no gradient needed.
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # Keep only tets the surface passes through.
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            # Index tensors follow the input's device instead of assuming "cuda".
            edge_device = unique_edges.device
            mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=edge_device) * -1
            mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=edge_device)
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]

        # Linear interpolation of the zero crossing along every crossed edge.
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        # Occupancy code per tet: bit i is set when vertex i is occupied.
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=occ_fx4.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        one_tri = num_triangles == 1
        two_tri = num_triangles == 2
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[one_tri],
                    dim=1,
                    index=self.triangle_table[tetindex[one_tri]][:, :3]).reshape(-1, 3),
                torch.gather(
                    input=idx_map[two_tri],
                    dim=1,
                    index=self.triangle_table[tetindex[two_tri]][:, :6]).reshape(-1, 3),
            ),
            dim=0,
        )
        if not with_uv:
            return verts, faces

        # Get global face index (static, does not depend on topology)
        num_tets = tet_fx4.shape[0]
        tet_gidx = torch.arange(num_tets, dtype=torch.long, device=tet_fx4.device)[valid_tets]
        face_gidx = torch.cat(
            (tet_gidx[one_tri] * 2,
             torch.stack((tet_gidx[two_tri] * 2, tet_gidx[two_tri] * 2 + 1), dim=-1).view(-1)),
            dim=0,
        )
        uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets * 2)

        return verts, faces, uvs, uv_idx
diff --git a/core/lib/network_utils.py b/core/lib/network_utils.py
new file mode 100755
index 0000000..727f765
--- /dev/null
+++ b/core/lib/network_utils.py
@@ -0,0 +1,208 @@
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+# Positional Encoding from https://github.com/yenchenlin/nerf-pytorch/blob/1f064835d2cca26e4df2d7d130daa39a8cee1795/run_nerf_helpers.py
+class Embedder:
+    """Sin/cos positional encoder configured through keyword options.
+
+    Expected keys in ``kwargs``: ``input_dims``, ``include_input``,
+    ``max_freq_log2``, ``num_freqs``, ``log_sampling`` and ``periodic_fns``.
+    After construction ``embed_fns`` holds one callable per output slice and
+    ``out_dim`` holds the width of the concatenated encoding.
+    """
+
+    def __init__(self, **kwargs):
+        self.kwargs = kwargs
+        self.create_embedding_fn()
+
+    def create_embedding_fn(self):
+        """Build the list of encoding callables and record the output width."""
+        opts = self.kwargs
+        width = opts['input_dims']
+
+        # Optional identity term so the raw coordinates survive the encoding.
+        fns = [lambda x: x] if opts['include_input'] else []
+
+        top = opts['max_freq_log2']
+        count = opts['num_freqs']
+        if opts['log_sampling']:
+            bands = 2.**torch.linspace(0., top, steps=count)
+        else:
+            bands = torch.linspace(2.**0., 2.**top, steps=count)
+
+        # Default arguments pin the current band / function into each lambda.
+        for band in bands:
+            fns += [
+                (lambda x, p_fn=fn, freq=band: p_fn(x * freq))
+                for fn in opts['periodic_fns']
+            ]
+
+        self.embed_fns = fns
+        # Every callable contributes one block of `input_dims` channels.
+        self.out_dim = width * len(fns)
+
+    def embed(self, inputs):
+        """Apply every encoding callable and concatenate on the last axis."""
+        pieces = [encode(inputs) for encode in self.embed_fns]
+        return torch.cat(pieces, -1)
+
+def get_embedder(multires):
+    """Create a 3-D sin/cos positional encoder with ``multires`` octaves.
+
+    Returns ``(embed, out_dim)``: ``embed`` maps a tensor to its encoding and
+    ``out_dim`` is the size of the encoded last dimension.
+    """
+    encoder = Embedder(
+        include_input=True,
+        input_dims=3,
+        max_freq_log2=multires - 1,
+        num_freqs=multires,
+        log_sampling=True,
+        periodic_fns=[torch.sin, torch.cos],
+    )
+
+    def embed(x, eo=encoder):
+        # `eo` stays overridable, exactly like the default argument it replaces.
+        return eo.embed(x)
+
+    return embed, encoder.out_dim
+
+
+class FreqEncoder_torch(nn.Module):
+    """Pure-PyTorch frequency (sin/cos) encoder.
+
+    ``output_dim`` is ``input_dim * N_freqs * len(periodic_fns)``, plus
+    ``input_dim`` when ``include_input`` is set. ``freq_bands`` is stored as a
+    plain list of Python floats.
+    """
+
+    def __init__(self, input_dim, max_freq_log2, N_freqs,
+                 log_sampling=True, include_input=True,
+                 periodic_fns=(torch.sin, torch.cos)):
+        super().__init__()
+
+        self.input_dim = input_dim
+        self.include_input = include_input
+        self.periodic_fns = periodic_fns
+
+        passthrough = self.input_dim if self.include_input else 0
+        self.output_dim = passthrough + self.input_dim * N_freqs * len(self.periodic_fns)
+
+        if log_sampling:
+            bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs)
+        else:
+            bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs)
+        self.freq_bands = bands.numpy().tolist()
+
+    def forward(self, input, **kwargs):
+        """Encode ``input``; extra keyword arguments are accepted and ignored."""
+        pieces = [input] if self.include_input else []
+        pieces.extend(
+            fn(input * freq)
+            for freq in self.freq_bands
+            for fn in self.periodic_fns
+        )
+        return torch.cat(pieces, dim=-1)
+
+def get_encoder(encoding, input_dim=3,
+                multires=6,
+                degree=4,
+                num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, interpolation='linear',
+                **kwargs):
+    """Build an input encoder by name.
+
+    Args:
+        encoding: one of ``'None'``, ``'frequency_torch'``, ``'frequency'``,
+            ``'sphere_harmonics'``, ``'hashgrid'`` or ``'tiledgrid'``.
+        input_dim: dimensionality of the raw input.
+        multires: number of frequency octaves for the frequency encoders.
+        degree: degree passed to the spherical-harmonics encoder.
+        num_levels, level_dim, base_resolution, log2_hashmap_size,
+        desired_resolution, align_corners, interpolation: forwarded to
+            ``GridEncoder`` for the two grid modes.
+        **kwargs: accepted for call-site compatibility and ignored.
+
+    Returns:
+        ``(encoder, output_dim)``. For ``'None'`` the encoder is an identity
+        callable and ``output_dim`` equals ``input_dim``.
+
+    Raises:
+        NotImplementedError: if ``encoding`` is not a supported mode.
+    """
+    if encoding == 'None':
+        return lambda x, **kwargs: x, input_dim
+
+    elif encoding == 'frequency_torch':
+        encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)
+
+    elif encoding == 'frequency': # CUDA implementation, faster than torch.
+        from .freqencoder import FreqEncoder
+        encoder = FreqEncoder(input_dim=input_dim, degree=multires)
+
+    elif encoding == 'sphere_harmonics':
+        # NOTE(review): imported as a top-level package, unlike the relative
+        # imports used for the other extensions - confirm this is intended.
+        from shencoder import SHEncoder
+        encoder = SHEncoder(input_dim=input_dim, degree=degree)
+
+    elif encoding == 'hashgrid':
+        from .gridencoder import GridEncoder
+        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners, interpolation=interpolation)
+
+    elif encoding == 'tiledgrid':
+        from .gridencoder import GridEncoder
+        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, interpolation=interpolation)
+
+    else:
+        # List every mode handled above (the old text omitted frequency_torch)
+        # and echo the offending value so misconfigurations are easy to spot.
+        raise NotImplementedError(
+            f'Unknown encoding mode {encoding!r}, choose from '
+            '[None, frequency_torch, frequency, sphere_harmonics, hashgrid, tiledgrid]')
+
+    return encoder, encoder.output_dim
+
+
+# MLP + Positional Encoding
+class Decoder(torch.nn.Module):
+    """Bias-free ReLU MLP with an optional sin/cos positional encoding.
+
+    Args:
+        input_dims: width of the raw input; replaced by the encoding width
+            when ``multires > 0``.
+        internal_dims: width of every hidden layer.
+        output_dims: width of the output layer.
+        hidden: number of hidden ``Linear``+``ReLU`` pairs.
+        multires: number of positional-encoding octaves; ``0`` or less
+            disables the encoding.
+    """
+
+    def __init__(self, input_dims = 3, internal_dims = 128, output_dims = 4, hidden = 8, multires = 5):
+        super().__init__()
+        self.embed_fn = None
+        if multires > 0:
+            embed_fn, input_ch = get_embedder(multires)
+            self.embed_fn = embed_fn
+            input_dims = input_ch
+
+        net = (torch.nn.Linear(input_dims, internal_dims, bias=False), torch.nn.ReLU())
+        for i in range(hidden-1):
+            net = net + (torch.nn.Linear(internal_dims, internal_dims, bias=False), torch.nn.ReLU())
+        net = net + (torch.nn.Linear(internal_dims, output_dims, bias=False),)
+        self.net = torch.nn.Sequential(*net)
+
+    def forward(self, p):
+        """Encode ``p`` (if an encoder is configured) and run the MLP."""
+        if self.embed_fn is not None:
+            p = self.embed_fn(p)
+        out = self.net(p)
+        return out
+
+    def pre_train_sphere(self, iter, device='cuda', axis_scale=1.):
+        """Fit output channel 0 to the signed distance of a radius-0.3 sphere.
+
+        Runs ``iter`` Adam steps (lr 1e-4) on batches of 1024 points drawn
+        uniformly from ``[-0.5, 0.5]^3`` and divided by ``axis_scale``.
+        """
+        print ("Initialize SDF to sphere")
+        loss_fn = torch.nn.MSELoss()
+        optimizer = torch.optim.Adam(list(self.parameters()), lr=1e-4)
+
+        # Stays None when `iter` is 0 so the summary below cannot hit an
+        # unbound local.
+        loss = None
+        for i in tqdm(range(iter)):
+            p = torch.rand((1024,3), device=device) - 0.5
+            p = p / axis_scale
+            ref_value = torch.sqrt((p**2).sum(-1)) - 0.3
+            output = self(p)
+            loss = loss_fn(output[...,0], ref_value)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+        if loss is not None:
+            print("Pre-trained MLP", loss.item())
+
+
+class HashDecoder(nn.Module):
+    """Hash-grid encoding followed by a small bias-free ReLU MLP.
+
+    Args:
+        input_dims: accepted for interface compatibility; the encoder is
+            always built for 3-D inputs and this value is overwritten by the
+            encoder's output width.
+        internal_dims: width of every hidden layer.
+        output_dims: width of the output layer.
+        hidden: number of hidden ``Linear``+``ReLU`` pairs.
+        input_bounds: optional ``(lower, upper)`` pair; when given, inputs are
+            rescaled from that box to ``[-1, 1]`` before encoding.
+        max_res: ``desired_resolution`` passed to the hash grid.
+        num_levels: number of hash-grid levels.
+        interpolation: interpolation mode passed to the hash grid.
+    """
+
+    def __init__(self, input_dims = 3, internal_dims = 32, output_dims = 4, hidden = 2, input_bounds=None, max_res=1024, num_levels=16, interpolation='smoothstep') -> None:
+        super().__init__()
+        self.input_bounds = input_bounds
+        self.embed_fn, input_dims = get_encoder(
+            'hashgrid',
+            input_dim=3,
+            log2_hashmap_size=19,
+            desired_resolution=max_res,
+            num_levels=num_levels,
+            interpolation=interpolation)
+        net = (torch.nn.Linear(input_dims, internal_dims, bias=False), torch.nn.ReLU())
+        for i in range(hidden-1):
+            net = net + (torch.nn.Linear(internal_dims, internal_dims, bias=False), torch.nn.ReLU())
+        net = net + (torch.nn.Linear(internal_dims, output_dims, bias=False),)
+        self.net = torch.nn.Sequential(*net)
+
+    def _evaluate(self, p):
+        """Shared forward path: rescale to [-1, 1], encode, run the MLP."""
+        if self.input_bounds is not None:
+            p = (p - self.input_bounds[0]) / (self.input_bounds[1] - self.input_bounds[0]) * 2 -1
+        if self.embed_fn is not None:
+            p = self.embed_fn(p)
+        return self.net(p)
+
+    def gradient(self, p):
+        """Return d(sum of outputs)/dp with an extra axis inserted at dim 1.
+
+        ``p`` is switched to ``requires_grad`` in place. The graph is kept
+        (``create_graph=True``) so the result can itself be differentiated.
+        """
+        p.requires_grad_(True)
+        # Differentiate exactly the function `forward` evaluates. The previous
+        # code rescaled to [0, 1] here but to [-1, 1] in `forward`, so the
+        # returned gradient belonged to a different field.
+        y = self._evaluate(p)
+        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
+        gradients = torch.autograd.grad(
+            outputs=y,
+            inputs=p,
+            grad_outputs=d_output,
+            create_graph=True,
+            retain_graph=True,
+            only_inputs=True)[0]
+        return gradients.unsqueeze(1)
+
+    def forward(self, p):
+        """Evaluate the decoder at points ``p``."""
+        out = self._evaluate(p)
+        return out
diff --git a/core/lib/obj.py b/core/lib/obj.py
new file mode 100755
index 0000000..7be7295
--- /dev/null
+++ b/core/lib/obj.py
@@ -0,0 +1,381 @@
+import os
+import cv2
+import torch
+import numpy as np
+
+
+def dot(x, y):
+    """Dot product along the last axis; that axis is kept with size 1."""
+    product = x * y
+    return product.sum(-1, keepdim=True)
+
+
+def length(x, eps=1e-20):
+    """Euclidean norm along the last axis, with the squared norm floored at ``eps``."""
+    squared = dot(x, x)
+    return torch.clamp(squared, min=eps).sqrt()
+
+
+def safe_normalize(x, eps=1e-20):
+    """Divide ``x`` by its clamped length so zero vectors do not divide by zero."""
+    norm = length(x, eps)
+    return x / norm
+
+
+def keep_largest(mesh):
+    """Return the connected component of ``mesh`` with the most vertices.
+
+    Components come from ``mesh.split(only_watertight=False)``; on a tie the
+    earliest component wins.
+    """
+    components = mesh.split(only_watertight=False)
+    winner = components[0]
+    for idx in range(len(components)):
+        candidate = components[idx]
+        # Strict comparison keeps the first of several equally large parts.
+        if candidate.vertices.shape[0] > winner.vertices.shape[0]:
+            winner = candidate
+    return winner
+
+def poisson(mesh, depth=10, face_count=500000):
+    """Rebuild ``mesh`` with Open3D Poisson reconstruction; keep the largest part.
+
+    The mesh is exported to a temporary ``.ply``, read back as a point cloud,
+    reconstructed, and converted to a ``trimesh.Trimesh``.
+
+    Args:
+        mesh: object exposing ``vertex_normals`` and ``export(path)``.
+        depth: octree depth passed to the Poisson reconstruction.
+        face_count: currently unused; kept so existing callers keep working.
+
+    Returns:
+        The largest connected component of the reconstructed mesh.
+    """
+    import tempfile
+    import open3d as o3d
+    import trimesh
+
+    # Kept ahead of the export, as in the original.
+    # NOTE(review): reading `vertex_normals` presumably also makes trimesh
+    # compute the normals that the export needs - confirm before moving it.
+    assert (mesh.vertex_normals.shape[1] == 3)
+
+    # A private temporary directory replaces the fixed "/tmp/_soups.ply",
+    # which concurrent runs would overwrite and which is not portable.
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        pcd_path = os.path.join(tmp_dir, "_soups.ply")
+        mesh.export(pcd_path)
+        pcl = o3d.io.read_point_cloud(pcd_path)
+
+    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Error) as cm:
+        mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
+            pcl, depth=depth, n_threads=-1
+        )
+
+    mesh = trimesh.Trimesh(np.array(mesh.vertices), np.array(mesh.triangles))
+
+    # only keep the largest component
+    largest_mesh = keep_largest(mesh)
+
+    return largest_mesh
+
+class Mesh():
+    """Triangle mesh container backed by torch tensors, with OBJ/PLY I/O.
+
+    Attributes set by ``__init__``: ``v`` (positions), ``f`` (faces), ``vn`` /
+    ``fn`` (normals and their indices), ``vt`` / ``ft`` (texture coordinates
+    and their indices), ``albedo`` (single texture), ``v_color``,
+    ``use_vertex_tex``, ``ref_v`` and ``device``. ``load_obj`` additionally
+    sets ``keypoints``; ``auto_size`` sets ``resize_matrix_inv``.
+    """
+
+    def __init__(self, v=None, f=None, vn=None, fn=None, vt=None, ft=None, albedo=None, device=None, base=None, split=False):
+        """Store the given tensors; ``None`` fields are copied from ``base``.
+
+        With ``split=True`` the mesh is first reduced to its largest
+        component and rebuilt with Poisson reconstruction (needs trimesh and
+        open3d); ``v`` and ``f`` are replaced by the result.
+        """
+        if split:
+            import trimesh
+            mesh = trimesh.Trimesh(v.cpu().detach().numpy(), f.cpu().detach().numpy(), process=True, validate=True)
+            mesh = poisson(keep_largest(mesh))
+            v = v.new_tensor(mesh.vertices)
+            f = f.new_tensor(mesh.faces)
+        self.v = v
+        self.vn = vn
+        self.vt = vt
+        self.f = f
+        self.fn = fn
+        self.ft = ft
+        self.v_color = None
+        self.use_vertex_tex = False
+        self.ref_v = None
+        # only support a single albedo
+        self.albedo = albedo
+        self.device = device
+        # copy non-None attribute from base
+        if isinstance(base, Mesh):
+            for name in ['v', 'vn', 'vt', 'f', 'fn', 'ft', 'albedo']:
+                if getattr(self, name) is None:
+                    setattr(self, name, getattr(base, name))
+
+    # load from obj file
+    @classmethod
+    def load_obj(cls, path, albedo_path=None, device=None, init_empty_tex=False, use_vertex_tex=False, albedo_res=2048, ref_path=None, keypoints_path=None, init_uv=True):
+        """Load a ``.obj`` or ``.ply`` file into a normalized ``Mesh``.
+
+        Args:
+            path: mesh file; the extension must be ``.obj`` or ``.ply``.
+            albedo_path: texture image; for ``.obj`` files the first
+                ``map_Kd`` of the sibling ``.mtl`` is used when omitted.
+            device: target device; defaults to CUDA when available.
+            init_empty_tex: force a flat grey texture.
+            use_vertex_tex: use per-vertex colours instead of a UV texture.
+            albedo_res: side length of the generated grey texture.
+            ref_path: optional reference mesh whose vertices drive the
+                bounding box used for normalization.
+            keypoints_path: optional ``.npy`` holding a dict with ``'joints'``.
+            init_uv: generate UVs with xatlas when the file has none.
+
+        Returns:
+            The loaded mesh, rescaled by ``auto_size``.
+        """
+        mesh = cls()
+
+        # device
+        if device is None:
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        if ref_path is not None:
+            import trimesh
+            mesh.ref_v = torch.tensor(trimesh.load(ref_path).vertices, dtype=torch.float32, device=device)
+        else:
+            mesh.ref_v = None
+
+        root, ext = os.path.splitext(path)
+        assert ext == '.obj' or ext == '.ply'
+
+        mesh.device = device
+
+        # try to find texture from mtl file
+        if albedo_path is None and ext == '.obj':
+            # Swap only the extension; str.replace would also rewrite any
+            # '.obj' that appears earlier in the path.
+            mtl_path = root + '.mtl'
+            if os.path.exists(mtl_path):
+                with open(mtl_path, 'r') as f:
+                    lines = f.readlines()
+                for line in lines:
+                    split_line = line.split()
+                    # empty line
+                    if len(split_line) == 0:
+                        continue
+                    prefix = split_line[0]
+                    # NOTE: simply use the first map_Kd as albedo!
+                    if 'map_Kd' in prefix and len(split_line) > 1:
+                        albedo_path = os.path.join(os.path.dirname(path), split_line[1])
+                        print(f'[load_obj] use albedo from: {albedo_path}')
+                        break
+
+        if init_empty_tex or albedo_path is None or not os.path.exists(albedo_path):
+            # init an empty texture
+            print(f'[load_obj] init empty albedo!')
+            albedo = np.ones((albedo_res, albedo_res, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color
+        else:
+            albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)
+            albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
+            albedo = albedo.astype(np.float32) / 255
+
+        mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)
+
+        vertices, texcoords, normals = [], [], []
+        faces, tfaces, nfaces = [], [], []
+        if ext == '.obj':
+
+            # load obj
+            with open(path, 'r') as f:
+                lines = f.readlines()
+
+            def parse_f_v(fv):
+                # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)
+                # supported forms:
+                # f v1 v2 v3
+                # f v1/vt1 v2/vt2 v3/vt3
+                # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
+                # f v1//vn1 v2//vn2 v3//vn3
+                xs = [int(x) - 1 if x != '' else -1 for x in fv.split('/')]
+                xs.extend([-1] * (3 - len(xs)))
+                return xs[0], xs[1], xs[2]
+
+            # NOTE: we ignore usemtl, and assume the mesh ONLY uses one material (first in mtl)
+            for line in lines:
+                split_line = line.split()
+                # empty line
+                if len(split_line) == 0:
+                    continue
+                # v/vn/vt
+                prefix = split_line[0].lower()
+                if prefix == 'v':
+                    vertices.append([float(v) for v in split_line[1:]])
+                elif prefix == 'vn':
+                    normals.append([float(v) for v in split_line[1:]])
+                elif prefix == 'vt':
+                    val = [float(v) for v in split_line[1:]]
+                    # The V axis is flipped on load and flipped back in write().
+                    texcoords.append([val[0], 1.0 - val[1]])
+                elif prefix == 'f':
+                    vs = split_line[1:]
+                    nv = len(vs)
+                    v0, t0, n0 = parse_f_v(vs[0])
+                    for i in range(nv - 2): # triangulate (assume vertices are ordered)
+                        v1, t1, n1 = parse_f_v(vs[i + 1])
+                        v2, t2, n2 = parse_f_v(vs[i + 2])
+                        faces.append([v0, v1, v2])
+                        tfaces.append([t0, t1, t2])
+                        nfaces.append([n0, n1, n2])
+        elif ext == '.ply':
+            import trimesh
+            trimesh_mesh = trimesh.load(path)
+            vertices = trimesh_mesh.vertices
+            faces = trimesh_mesh.faces
+            if isinstance(trimesh_mesh.visual, trimesh.visual.ColorVisuals):
+                # Colours are appended as columns 3..5 and split off again below.
+                vertices_colors = np.array(trimesh_mesh.visual.vertex_colors[:, :3]/255)
+                vertices = np.concatenate([vertices, vertices_colors], axis=-1)
+
+        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
+        mesh.vt = torch.tensor(texcoords, dtype=torch.float32, device=device) if len(texcoords) > 0 else None
+        mesh.vn = torch.tensor(normals, dtype=torch.float32, device=device) if len(normals) > 0 else None
+        mesh.use_vertex_tex = use_vertex_tex
+        if mesh.v.shape[1] == 6:
+            mesh.v_color = mesh.v[:, 3:]
+            mesh.v = mesh.v[:, :3]
+        elif mesh.use_vertex_tex:
+            mesh.v_color = torch.ones_like(mesh.v) * 0.5
+        else:
+            mesh.v_color = None
+
+        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
+        # NOTE(review): `texcoords` / `normals` are always lists here, so these
+        # conditions are always true and ft / fn are never None after loading
+        # (they hold -1 entries or are empty when the file has no vt / vn).
+        # Left unchanged because callers may rely on it.
+        mesh.ft = torch.tensor(tfaces, dtype=torch.int32, device=device) if texcoords is not None else None
+        mesh.fn = torch.tensor(nfaces, dtype=torch.int32, device=device) if normals is not None else None
+
+        if keypoints_path is not None:
+            # NOTE(review): allow_pickle=True executes pickled code on load;
+            # only point this at trusted files.
+            mesh.keypoints = np.load(keypoints_path, allow_pickle=True).item()['joints'].to(device)
+            if len(mesh.keypoints.shape) == 2:
+                mesh.keypoints = mesh.keypoints[None]
+        elif len(mesh.v) == 6890: # SMPL mesh init
+            import json
+            with open('smpl_vert_segmentation.json') as f:
+                segmentation = json.load(f)
+            head_ind = segmentation['head']
+            mesh.keypoints = mesh.v[head_ind].mean(dim=0)[None, None]
+        elif mesh.ref_v is not None and len(mesh.ref_v) == 6890: # SMPL mesh init
+            import json
+            with open('smpl_vert_segmentation.json', 'r') as f:
+                segmentation = json.load(f)
+            head_ind = segmentation['head']
+            mesh.keypoints = mesh.ref_v[head_ind].mean(dim=0)[None, None]
+        else:
+            mesh.keypoints = None
+        # keypoints may legitimately be None; printing `.shape` would crash.
+        print('mesh keypoints', mesh.keypoints.shape if mesh.keypoints is not None else None)
+
+        # auto-normalize
+        mesh.auto_size()
+
+        print(f'[load_obj] v: {mesh.v.shape}, f: {mesh.f.shape}')
+
+        # auto-fix normal
+        if mesh.vn is None:
+            mesh.auto_normal()
+
+        print(f'[load_obj] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}')
+
+        # auto-fix texture
+        if mesh.vt is None and not use_vertex_tex and init_uv:
+            mesh.auto_uv(cache_path=path)
+
+        # vt stays None with use_vertex_tex=True or init_uv=False.
+        vt_shape = mesh.vt.shape if mesh.vt is not None else None
+        ft_shape = mesh.ft.shape if mesh.ft is not None else None
+        print(f'[load_obj] vt: {vt_shape}, ft: {ft_shape}')
+
+        return mesh
+
+    # aabb
+    def aabb(self):
+        """Return ``(min, max)`` corners, taken from ``ref_v`` when it is set."""
+        if hasattr(self, 'ref_v') and self.ref_v is not None:
+            return torch.min(self.ref_v, dim=0).values, torch.max(self.ref_v, dim=0).values
+        return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values
+
+    # unit size
+    @torch.no_grad()
+    def auto_size(self): # to [-0.5, 0.5]
+        """Center the mesh and scale its longest bounding-box side to 1.
+
+        The same transform is applied to ``keypoints`` and ``ref_v`` when
+        present. ``resize_matrix_inv`` stores the 4x4 matrix that maps
+        normalized coordinates back to the original ones.
+        """
+        vmin, vmax = self.aabb()
+        v_c = (vmax + vmin) / 2
+        scale = 1 / torch.max(vmax - vmin).item()
+        self.v = (self.v - v_c) * scale
+        if hasattr(self, 'keypoints') and self.keypoints is not None:
+            self.keypoints = (self.keypoints - v_c)*scale
+        if hasattr(self, 'ref_v') and self.ref_v is not None:
+            self.ref_v = (self.ref_v - v_c)*scale
+        cx, cy, cz = v_c.tolist()
+        self.resize_matrix_inv = torch.tensor([
+            [1/scale, 0, 0, cx],
+            [0, 1/scale, 0, cy],
+            [0, 0, 1/scale, cz],
+            [0, 0, 0, 1],
+        ], dtype=torch.float, device=self.device)
+
+    def auto_normal(self):
+        """Compute area-weighted vertex normals and set ``fn`` to ``f``."""
+        i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
+        v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]
+        # dim must be explicit: without it torch.cross uses the first axis of
+        # size 3, which is the face axis when the mesh has exactly 3 faces.
+        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
+
+        # Splat face normals to vertices
+        vn = torch.zeros_like(self.v)
+        vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
+        vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
+        vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
+
+        # Normalize, replace zero (degenerated) normals with some default value
+        vn = torch.where(dot(vn, vn) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
+        vn = safe_normalize(vn)
+
+        self.vn = vn
+        self.fn = self.f
+
+    def auto_uv(self, cache_path=None):
+        """Generate ``vt`` / ``ft`` with xatlas, using an ``*_uv.npz`` cache.
+
+        ``cache_path`` is the mesh path; its extension is replaced by
+        ``_uv.npz`` to form the cache file name.
+        """
+        print('[INFO] Using atlas to calculate UV. It takes 10~20min.')
+        # try to load cache
+        if cache_path is not None:
+            # Replace whatever extension the mesh has. The old
+            # `.replace('.obj', ...)` left '.ply' paths untouched, so the mesh
+            # file itself was then opened as the cache.
+            cache_path = os.path.splitext(cache_path)[0] + '_uv.npz'
+        if cache_path and os.path.exists(cache_path):
+            data = np.load(cache_path)
+            vt_np, ft_np = data['vt'], data['ft']
+        else:
+
+            import xatlas
+            v_np = self.v.cpu().numpy() * 100
+            f_np = self.f.int().cpu().numpy()
+            atlas = xatlas.Atlas()
+            atlas.add_mesh(v_np, f_np)
+            chart_options = xatlas.ChartOptions()
+            chart_options.max_iterations = 4
+            atlas.generate(chart_options=chart_options)
+            vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
+
+            # save to cache
+            if cache_path:
+                np.savez(cache_path, vt=vt_np, ft=ft_np)
+
+        vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
+        ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
+
+        self.vt = vt
+        self.ft = ft
+
+    def to(self, device):
+        """Move every tensor attribute to ``device`` and return ``self``."""
+        self.device = device
+        # The last four are optional and may be absent or None; previously
+        # they stayed behind on the old device.
+        names = ['v', 'f', 'vn', 'fn', 'vt', 'ft', 'albedo',
+                 'v_color', 'ref_v', 'keypoints', 'resize_matrix_inv']
+        for name in names:
+            tensor = getattr(self, name, None)
+            if torch.is_tensor(tensor):
+                setattr(self, name, tensor.to(device))
+        return self
+
+    # write to obj file
+    def write(self, path):
+        """Write the mesh as Wavefront OBJ, plus ``.mtl`` / albedo PNG with UVs.
+
+        Sidecar files share ``path``'s stem: ``<stem>.mtl`` and
+        ``<stem>_albedo.png``.
+        """
+        # Swap only the extension so a path without '.obj' cannot make the
+        # .mtl overwrite the mesh file.
+        stem = os.path.splitext(path)[0]
+        mtl_path = stem + '.mtl'
+        albedo_path = stem + '_albedo.png'
+
+        def to_np(tensor):
+            # detach() lets meshes that are still being optimized be exported.
+            return tensor.detach().cpu().numpy() if tensor is not None else None
+
+        v_np = to_np(self.v)
+        vt_np = to_np(self.vt)
+        vn_np = to_np(self.vn)
+        f_np = to_np(self.f)
+        ft_np = to_np(self.ft)
+        fn_np = to_np(self.fn)
+        vc_np = to_np(self.v_color)
+        print(f'vertice num: {len(v_np)}, face num: {len(f_np)}')
+
+        def corner(i, k):
+            # One "v/vt/vn" face corner with 1-based indices; missing index
+            # arrays leave their slot empty.
+            t = ft_np[i, k] + 1 if ft_np is not None else ""
+            n = fn_np[i, k] + 1 if fn_np is not None else ""
+            return f'{f_np[i, k] + 1}/{t}/{n}'
+
+        with open(path, "w") as fp:
+            fp.write(f'mtllib {os.path.basename(mtl_path)} \n')
+            # Fall back to plain positions when colours were requested but
+            # never set.
+            if self.use_vertex_tex and vc_np is not None:
+                for v, c in zip(v_np, vc_np):
+                    fp.write(f'v {v[0]} {v[1]} {v[2]} {c[0]} {c[1]} {c[2]}\n')
+            else:
+                for v in v_np:
+                    fp.write(f'v {v[0]} {v[1]} {v[2]} \n')
+            if vt_np is not None:
+                for v in vt_np:
+                    fp.write(f'vt {v[0]} {1 - v[1]} \n')
+            if vn_np is not None:
+                for v in vn_np:
+                    fp.write(f'vn {v[0]} {v[1]} {v[2]} \n')
+            # Face records are separated by single spaces (the old f-string
+            # line continuations embedded runs of spaces).
+            if vt_np is not None:
+                fp.write(f'usemtl defaultMat \n')
+                for i in range(len(f_np)):
+                    fp.write(f'f {corner(i, 0)} {corner(i, 1)} {corner(i, 2)} \n')
+            else:
+                for i in range(len(f_np)):
+                    fp.write(f'f {f_np[i, 0] + 1} {f_np[i, 1] + 1} {f_np[i, 2] + 1} \n')
+
+        if vt_np is not None:
+            has_albedo = self.albedo is not None
+            with open(mtl_path, "w") as fp:
+                fp.write(f'newmtl defaultMat \n')
+                fp.write(f'Ka 1 1 1 \n')
+                fp.write(f'Kd 1 1 1 \n')
+                fp.write(f'Ks 0 0 0 \n')
+                fp.write(f'Tr 1 \n')
+                fp.write(f'illum 1 \n')
+                fp.write(f'Ns 0 \n')
+                if not self.use_vertex_tex and has_albedo:
+                    fp.write(f'map_Kd {os.path.basename(albedo_path)} \n')
+
+            # NOTE(review): the texture is exported whenever UVs and an albedo
+            # exist; confirm this matches the intended nesting.
+            if has_albedo:
+                albedo = self.albedo.detach().cpu().numpy()
+                albedo = (albedo * 255).astype(np.uint8)
+                cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))
diff --git a/core/lib/optimizer.py b/core/lib/optimizer.py
new file mode 100755
index 0000000..f5bb64f
--- /dev/null
+++ b/core/lib/optimizer.py
@@ -0,0 +1,325 @@
+# Copyright 2022 Garena Online Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List
+
+import torch
+from torch import Tensor
+from torch.optim.optimizer import Optimizer
+
+
+class Adan(Optimizer):
+ """
+ Implements a pytorch variant of Adan
+ Adan was proposed in
+ Adan: Adaptive Nesterov Momentum Algorithm for
+ Faster Optimizing Deep Models[J].arXiv preprint arXiv:2208.06677, 2022.
+ https://arxiv.org/abs/2208.06677
+ Arguments:
+ params (iterable): iterable of parameters to optimize or
+ dicts defining parameter groups.
+ lr (float, optional): learning rate. (default: 1e-3)
+ betas (Tuple[float, float, float], optional): coefficients used for
+ first- and second-order moments. (default: (0.98, 0.92, 0.99))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability. (default: 1e-8)
+ weight_decay (float, optional): decoupled weight decay
+ (L2 penalty) (default: 0)
+ max_grad_norm (float, optional): value used to clip
+ global grad norm (default: 0.0 no clip)
+ no_prox (bool): how to perform the decoupled weight decay
+ (default: False)
+ foreach (bool): if True would use torch._foreach implementation.
+ It's faster but uses slightly more memory. (default: True)
+ """
+ # Validates the hyper-parameters and stores them as the per-group defaults.
+ # Raises ValueError for a negative max_grad_norm / lr / eps or a beta
+ # outside [0, 1).
+ def __init__(self,
+ params,
+ lr=1e-3,
+ betas=(0.98, 0.92, 0.99),
+ eps=1e-8,
+ weight_decay=0.0,
+ max_grad_norm=0.0,
+ no_prox=False,
+ foreach: bool = True):
+ if not 0.0 <= max_grad_norm:
+ raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm))
+ if not 0.0 <= lr:
+ raise ValueError('Invalid learning rate: {}'.format(lr))
+ if not 0.0 <= eps:
+ raise ValueError('Invalid epsilon value: {}'.format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError('Invalid beta parameter at index 0: {}'.format(
+ betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError('Invalid beta parameter at index 1: {}'.format(
+ betas[1]))
+ if not 0.0 <= betas[2] < 1.0:
+ raise ValueError('Invalid beta parameter at index 2: {}'.format(
+ betas[2]))
+ defaults = dict(lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ max_grad_norm=max_grad_norm,
+ no_prox=no_prox,
+ foreach=foreach)
+ super().__init__(params, defaults)
+
+ # Restores optimizer state; groups saved without 'no_prox' get False.
+ def __setstate__(self, state):
+ super(Adan, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('no_prox', False)
+
+ # Resets every group's step counter to 0 and zeroes the three moving
+ # averages of every parameter that requires grad. 'neg_pre_grad' is not
+ # touched here; step() re-seeds it when the group's step is 1.
+ @torch.no_grad()
+ def restart_opt(self):
+ for group in self.param_groups:
+ group['step'] = 0
+ for p in group['params']:
+ if p.requires_grad:
+ state = self.state[p]
+ # State initialization
+
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p)
+ # Exponential moving average of gradient difference
+ state['exp_avg_diff'] = torch.zeros_like(p)
+
+ # Runs one update over all parameter groups and returns the closure's
+ # loss (None when no closure is given).
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step."""
+
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ # Optional global gradient clipping: the L2 norm is accumulated over the
+ # gradients of all groups and turned into one scale factor <= 1.
+ if self.defaults['max_grad_norm'] > 0:
+ device = self.param_groups[0]['params'][0].device
+ global_grad_norm = torch.zeros(1, device=device)
+
+ max_grad_norm = torch.tensor(self.defaults['max_grad_norm'],
+ device=device)
+ for group in self.param_groups:
+
+ for p in group['params']:
+ if p.grad is not None:
+ grad = p.grad
+ global_grad_norm.add_(grad.pow(2).sum())
+
+ global_grad_norm = torch.sqrt(global_grad_norm)
+
+ # NOTE(review): `group` is the variable left over from the loop above,
+ # so the eps used here appears to be the last group's - confirm.
+ clip_global_grad_norm = torch.clamp(
+ max_grad_norm / (global_grad_norm + group['eps']),
+ max=1.0).item()
+ else:
+ clip_global_grad_norm = 1.0
+
+ for group in self.param_groups:
+ params_with_grad = []
+ grads = []
+ exp_avgs = []
+ exp_avg_sqs = []
+ exp_avg_diffs = []
+ neg_pre_grads = []
+
+ beta1, beta2, beta3 = group['betas']
+ # assume same step across group now to simplify things
+ # per parameter step can be easily support
+ # by making it tensor, or pass list into kernel
+ if 'step' in group:
+ group['step'] += 1
+ else:
+ group['step'] = 1
+
+ bias_correction1 = 1.0 - beta1**group['step']
+ bias_correction2 = 1.0 - beta2**group['step']
+ bias_correction3 = 1.0 - beta3**group['step']
+
+ # Collect the tensors of every parameter that has a gradient; state
+ # is created lazily on first use.
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ params_with_grad.append(p)
+ grads.append(p.grad)
+
+ state = self.state[p]
+ if len(state) == 0:
+ state['exp_avg'] = torch.zeros_like(p)
+ state['exp_avg_sq'] = torch.zeros_like(p)
+ state['exp_avg_diff'] = torch.zeros_like(p)
+
+ # Seeded with the negated, clipped current gradient so that the
+ # gradient difference formed by the kernel is zero on this step.
+ if 'neg_pre_grad' not in state or group['step'] == 1:
+ state['neg_pre_grad'] = p.grad.clone().mul_(
+ -clip_global_grad_norm)
+
+ exp_avgs.append(state['exp_avg'])
+ exp_avg_sqs.append(state['exp_avg_sq'])
+ exp_avg_diffs.append(state['exp_avg_diff'])
+ neg_pre_grads.append(state['neg_pre_grad'])
+
+ kwargs = dict(
+ params=params_with_grad,
+ grads=grads,
+ exp_avgs=exp_avgs,
+ exp_avg_sqs=exp_avg_sqs,
+ exp_avg_diffs=exp_avg_diffs,
+ neg_pre_grads=neg_pre_grads,
+ beta1=beta1,
+ beta2=beta2,
+ beta3=beta3,
+ bias_correction1=bias_correction1,
+ bias_correction2=bias_correction2,
+ bias_correction3_sqrt=math.sqrt(bias_correction3),
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ no_prox=group['no_prox'],
+ clip_global_grad_norm=clip_global_grad_norm,
+ )
+
+ # Both kernels take identical arguments; 'foreach' selects the
+ # torch._foreach implementation.
+ if group['foreach']:
+ _multi_tensor_adan(**kwargs)
+ else:
+ _single_tensor_adan(**kwargs)
+
+ return loss
+
+
+# Per-tensor Adan update. Every list is indexed in parallel with `params`;
+# parameters, gradients and state tensors are all modified in place.
+# `neg_pre_grads[i]` must hold the negated previous (clipped) gradient on
+# entry and holds the negated current one on exit.
+# `clip_global_grad_norm` is annotated as Tensor, but Adan.step passes a
+# Python float.
+def _single_tensor_adan(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ exp_avg_diffs: List[Tensor],
+ neg_pre_grads: List[Tensor],
+ *,
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ bias_correction1: float,
+ bias_correction2: float,
+ bias_correction3_sqrt: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+ no_prox: bool,
+ clip_global_grad_norm: Tensor,
+):
+ for i, param in enumerate(params):
+ grad = grads[i]
+ exp_avg = exp_avgs[i]
+ exp_avg_sq = exp_avg_sqs[i]
+ exp_avg_diff = exp_avg_diffs[i]
+ neg_grad_or_diff = neg_pre_grads[i]
+
+ # Apply the global clipping factor to the gradient in place.
+ grad.mul_(clip_global_grad_norm)
+
+ # for memory saving, we use `neg_grad_or_diff`
+ # to get some temp variable in a inplace way
+ # After this add the buffer holds (current grad - previous grad).
+ neg_grad_or_diff.add_(grad)
+
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t
+ exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff,
+ alpha=1 - beta2) # diff_t
+
+ # The buffer now becomes grad + beta2 * (grad - previous grad), whose
+ # square feeds the second-moment average.
+ neg_grad_or_diff.mul_(beta2).add_(grad)
+ exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff,
+ neg_grad_or_diff,
+ value=1 - beta3) # n_t
+
+ denom = ((exp_avg_sq).sqrt() / bias_correction3_sqrt).add_(eps)
+ step_size_diff = lr * beta2 / bias_correction2
+ step_size = lr / bias_correction1
+
+ # no_prox: shrink the weights first, then take the step.
+ # Otherwise: take the step, then divide by (1 + lr * weight_decay).
+ if no_prox:
+ param.mul_(1 - lr * weight_decay)
+ param.addcdiv_(exp_avg, denom, value=-step_size)
+ param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
+ else:
+ param.addcdiv_(exp_avg, denom, value=-step_size)
+ param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
+ param.div_(1 + lr * weight_decay)
+
+ # Store -grad so the next call can form the gradient difference.
+ neg_grad_or_diff.zero_().add_(grad, alpha=-1.0)
+
+
+# Same update as _single_tensor_adan, expressed with torch._foreach_* calls
+# that operate on whole tensor lists at once. All lists are modified in place.
+# `clip_global_grad_norm` is annotated as Tensor, but Adan.step passes a
+# Python float.
+def _multi_tensor_adan(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ exp_avg_diffs: List[Tensor],
+ neg_pre_grads: List[Tensor],
+ *,
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ bias_correction1: float,
+ bias_correction2: float,
+ bias_correction3_sqrt: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+ no_prox: bool,
+ clip_global_grad_norm: Tensor,
+):
+ # Nothing to update when no parameter of the group had a gradient.
+ if len(params) == 0:
+ return
+
+ # Apply the global clipping factor to every gradient in place.
+ torch._foreach_mul_(grads, clip_global_grad_norm)
+
+ # for memory saving, we use `neg_pre_grads`
+ # to get some temp variable in a inplace way
+ # After this add the buffers hold (current grad - previous grad).
+ torch._foreach_add_(neg_pre_grads, grads)
+
+ torch._foreach_mul_(exp_avgs, beta1)
+ torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t
+
+ torch._foreach_mul_(exp_avg_diffs, beta2)
+ torch._foreach_add_(exp_avg_diffs, neg_pre_grads,
+ alpha=1 - beta2) # diff_t
+
+ # The buffers now become grad + beta2 * (grad - previous grad), whose
+ # square feeds the second-moment average.
+ torch._foreach_mul_(neg_pre_grads, beta2)
+ torch._foreach_add_(neg_pre_grads, grads)
+ torch._foreach_mul_(exp_avg_sqs, beta3)
+ torch._foreach_addcmul_(exp_avg_sqs,
+ neg_pre_grads,
+ neg_pre_grads,
+ value=1 - beta3) # n_t
+
+ denom = torch._foreach_sqrt(exp_avg_sqs)
+ torch._foreach_div_(denom, bias_correction3_sqrt)
+ torch._foreach_add_(denom, eps)
+
+ step_size_diff = lr * beta2 / bias_correction2
+ step_size = lr / bias_correction1
+
+ # no_prox: shrink the weights first, then take the step.
+ # Otherwise: take the step, then divide by (1 + lr * weight_decay).
+ if no_prox:
+ torch._foreach_mul_(params, 1 - lr * weight_decay)
+ torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
+ torch._foreach_addcdiv_(params,
+ exp_avg_diffs,
+ denom,
+ value=-step_size_diff)
+ else:
+ torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
+ torch._foreach_addcdiv_(params,
+ exp_avg_diffs,
+ denom,
+ value=-step_size_diff)
+ torch._foreach_div_(params, 1 + lr * weight_decay)
+ # Store -grad so the next call can form the gradient difference.
+ torch._foreach_zero_(neg_pre_grads)
+ torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0)
\ No newline at end of file
diff --git a/core/lib/pose_utils.py b/core/lib/pose_utils.py
new file mode 100755
index 0000000..81da643
--- /dev/null
+++ b/core/lib/pose_utils.py
@@ -0,0 +1,552 @@
+# Coordinates below this value are treated as "keypoint missing" by the draw_* helpers in this file.
+eps = 0.01
+# Body joint names plus the jaw. draw_openpose_map() keeps every keypoint whose name is
+# listed here, regardless of the visibility mask it receives.
+JOINT_NAMES = [
+ "pelvis",
+ "left_hip",
+ "right_hip",
+ "spine1",
+ "left_knee",
+ "right_knee",
+ "spine2",
+ "left_ankle",
+ "right_ankle",
+ "spine3",
+ "left_foot",
+ "right_foot",
+ "neck",
+ "left_collar",
+ "right_collar",
+ "head",
+ "left_shoulder",
+ "right_shoulder",
+ "left_elbow",
+ "right_elbow",
+ "left_wrist",
+ "right_wrist",
+ "jaw"
+]
+
+# Keypoint names in the order draw_openpose_map() expects for its `smplx_keypoints_2d`
+# argument: that function zips this list with the keypoints and asserts equal length.
+SMPLX_NAMES = [
+ # body joints and jaw (same entries, same order, as JOINT_NAMES)
+ "pelvis",
+ "left_hip",
+ "right_hip",
+ "spine1",
+ "left_knee",
+ "right_knee",
+ "spine2",
+ "left_ankle",
+ "right_ankle",
+ "spine3",
+ "left_foot",
+ "right_foot",
+ "neck",
+ "left_collar",
+ "right_collar",
+ "head",
+ "left_shoulder",
+ "right_shoulder",
+ "left_elbow",
+ "right_elbow",
+ "left_wrist",
+ "right_wrist",
+ "jaw",
+ # eyes
+ "left_eye_smplx",
+ "right_eye_smplx",
+ # left-hand finger joints
+ "left_index1",
+ "left_index2",
+ "left_index3",
+ "left_middle1",
+ "left_middle2",
+ "left_middle3",
+ "left_pinky1",
+ "left_pinky2",
+ "left_pinky3",
+ "left_ring1",
+ "left_ring2",
+ "left_ring3",
+ "left_thumb1",
+ "left_thumb2",
+ "left_thumb3",
+ # right-hand finger joints
+ "right_index1",
+ "right_index2",
+ "right_index3",
+ "right_middle1",
+ "right_middle2",
+ "right_middle3",
+ "right_pinky1",
+ "right_pinky2",
+ "right_pinky3",
+ "right_ring1",
+ "right_ring2",
+ "right_ring3",
+ "right_thumb1",
+ "right_thumb2",
+ "right_thumb3",
+ # face landmarks: brows, nose, eyes, mouth, lips, contour
+ "right_eye_brow1",
+ "right_eye_brow2",
+ "right_eye_brow3",
+ "right_eye_brow4",
+ "right_eye_brow5",
+ "left_eye_brow5",
+ "left_eye_brow4",
+ "left_eye_brow3",
+ "left_eye_brow2",
+ "left_eye_brow1",
+ "nose1",
+ "nose2",
+ "nose3",
+ "nose4",
+ "right_nose_2",
+ "right_nose_1",
+ "nose_middle",
+ "left_nose_1",
+ "left_nose_2",
+ "right_eye1",
+ "right_eye2",
+ "right_eye3",
+ "right_eye4",
+ "right_eye5",
+ "right_eye6",
+ "left_eye4",
+ "left_eye3",
+ "left_eye2",
+ "left_eye1",
+ "left_eye6",
+ "left_eye5",
+ "right_mouth_1",
+ "right_mouth_2",
+ "right_mouth_3",
+ "mouth_top",
+ "left_mouth_3",
+ "left_mouth_2",
+ "left_mouth_1",
+ "left_mouth_5",
+ "left_mouth_4",
+ "mouth_bottom",
+ "right_mouth_4",
+ "right_mouth_5",
+ "right_lip_1",
+ "right_lip_2",
+ "lip_top",
+ "left_lip_2",
+ "left_lip_1",
+ "left_lip_3",
+ "lip_bottom",
+ "right_lip_3",
+ "right_contour_1",
+ "right_contour_2",
+ "right_contour_3",
+ "right_contour_4",
+ "right_contour_5",
+ "right_contour_6",
+ "right_contour_7",
+ "right_contour_8",
+ "contour_middle",
+ "left_contour_8",
+ "left_contour_7",
+ "left_contour_6",
+ "left_contour_5",
+ "left_contour_4",
+ "left_contour_3",
+ "left_contour_2",
+ "left_contour_1",
+ # extra keypoints; the un-numbered finger names are presumably fingertips - TODO confirm
+ "head_top",
+ "left_big_toe",
+ "left_ear",
+ "left_eye",
+ "left_heel",
+ "left_index",
+ "left_middle",
+ "left_pinky",
+ "left_ring",
+ "left_small_toe",
+ "left_thumb",
+ "nose",
+ "right_big_toe",
+ "right_ear",
+ "right_eye",
+ "right_heel",
+ "right_index",
+ "right_middle",
+ "right_pinky",
+ "right_ring",
+ "right_small_toe",
+ "right_thumb",
+]
+
+# Full OpenPose-style keypoint ordering: body (including pelvis and feet), left hand,
+# right hand, then face landmarks. "left_wrist"/"right_wrist" appear twice because each
+# hand section starts with its wrist.
+# NOTE(review): not referenced by any function in this part of the file.
+OPENPOSE_NAMES = [
+ # body
+ "nose",
+ "neck",
+ "right_shoulder",
+ "right_elbow",
+ "right_wrist",
+ "left_shoulder",
+ "left_elbow",
+ "left_wrist",
+ "pelvis",
+ "right_hip",
+ "right_knee",
+ "right_ankle",
+ "left_hip",
+ "left_knee",
+ "left_ankle",
+ "right_eye",
+ "left_eye",
+ "right_ear",
+ "left_ear",
+ "left_big_toe",
+ "left_small_toe",
+ "left_heel",
+ "right_big_toe",
+ "right_small_toe",
+ "right_heel",
+ # left hand (same entries as OPENPOSE_LEFT_HAND)
+ "left_wrist",
+ "left_thumb1",
+ "left_thumb2",
+ "left_thumb3",
+ "left_thumb",
+ "left_index1",
+ "left_index2",
+ "left_index3",
+ "left_index",
+ "left_middle1",
+ "left_middle2",
+ "left_middle3",
+ "left_middle",
+ "left_ring1",
+ "left_ring2",
+ "left_ring3",
+ "left_ring",
+ "left_pinky1",
+ "left_pinky2",
+ "left_pinky3",
+ "left_pinky",
+ # right hand (same entries as OPENPOSE_RIGHT_HAND)
+ "right_wrist",
+ "right_thumb1",
+ "right_thumb2",
+ "right_thumb3",
+ "right_thumb",
+ "right_index1",
+ "right_index2",
+ "right_index3",
+ "right_index",
+ "right_middle1",
+ "right_middle2",
+ "right_middle3",
+ "right_middle",
+ "right_ring1",
+ "right_ring2",
+ "right_ring3",
+ "right_ring",
+ "right_pinky1",
+ "right_pinky2",
+ "right_pinky3",
+ "right_pinky",
+ # face (same entries as OPENPOSE_FACE)
+ "right_eye_brow1",
+ "right_eye_brow2",
+ "right_eye_brow3",
+ "right_eye_brow4",
+ "right_eye_brow5",
+ "left_eye_brow5",
+ "left_eye_brow4",
+ "left_eye_brow3",
+ "left_eye_brow2",
+ "left_eye_brow1",
+ "nose1",
+ "nose2",
+ "nose3",
+ "nose4",
+ "right_nose_2",
+ "right_nose_1",
+ "nose_middle",
+ "left_nose_1",
+ "left_nose_2",
+ "right_eye1",
+ "right_eye2",
+ "right_eye3",
+ "right_eye4",
+ "right_eye5",
+ "right_eye6",
+ "left_eye4",
+ "left_eye3",
+ "left_eye2",
+ "left_eye1",
+ "left_eye6",
+ "left_eye5",
+ "right_mouth_1",
+ "right_mouth_2",
+ "right_mouth_3",
+ "mouth_top",
+ "left_mouth_3",
+ "left_mouth_2",
+ "left_mouth_1",
+ "left_mouth_5",
+ "left_mouth_4",
+ "mouth_bottom",
+ "right_mouth_4",
+ "right_mouth_5",
+ "right_lip_1",
+ "right_lip_2",
+ "lip_top",
+ "left_lip_2",
+ "left_lip_1",
+ "left_lip_3",
+ "lip_bottom",
+ "right_lip_3",
+ "right_contour_1",
+ "right_contour_2",
+ "right_contour_3",
+ "right_contour_4",
+ "right_contour_5",
+ "right_contour_6",
+ "right_contour_7",
+ "right_contour_8",
+ "contour_middle",
+ "left_contour_8",
+ "left_contour_7",
+ "left_contour_6",
+ "left_contour_5",
+ "left_contour_4",
+ "left_contour_3",
+ "left_contour_2",
+ "left_contour_1"
+]
+
+# The 18 body keypoints handed to draw_bodypose() by draw_openpose_map(), in the order
+# that function's limb table indexes them (1-based there). Unlike the body section of
+# OPENPOSE_NAMES, this list has no pelvis and no toe/heel entries.
+OPENPOSE_BODY = [
+ "nose",
+ "neck",
+ "right_shoulder",
+ "right_elbow",
+ "right_wrist",
+ "left_shoulder",
+ "left_elbow",
+ "left_wrist",
+ "right_hip",
+ "right_knee",
+ "right_ankle",
+ "left_hip",
+ "left_knee",
+ "left_ankle",
+ "right_eye",
+ "left_eye",
+ "right_ear",
+ "left_ear",
+]
+
+# Left-hand keypoints in the order draw_handpose() indexes them: the wrist first, then
+# for each finger (thumb, index, middle, ring, pinky) its joints 1-3 followed by the
+# un-numbered entry.
+OPENPOSE_LEFT_HAND = ["left_wrist"] + [
+    "left_" + finger + suffix
+    for finger in ("thumb", "index", "middle", "ring", "pinky")
+    for suffix in ("1", "2", "3", "")
+]
+
+# Right-hand keypoints in the order draw_handpose() indexes them: the wrist first, then
+# for each finger (thumb, index, middle, ring, pinky) its joints 1-3 followed by the
+# un-numbered entry.
+OPENPOSE_RIGHT_HAND = ["right_wrist"] + [
+    "right_" + finger + suffix
+    for finger in ("thumb", "index", "middle", "ring", "pinky")
+    for suffix in ("1", "2", "3", "")
+]
+
+# Face landmark names looked up by draw_openpose_map() and drawn as dots by
+# draw_facepose(): brows, nose, eyes, mouth, lips, then the face contour.
+OPENPOSE_FACE = [
+ "right_eye_brow1",
+ "right_eye_brow2",
+ "right_eye_brow3",
+ "right_eye_brow4",
+ "right_eye_brow5",
+ "left_eye_brow5",
+ "left_eye_brow4",
+ "left_eye_brow3",
+ "left_eye_brow2",
+ "left_eye_brow1",
+ "nose1",
+ "nose2",
+ "nose3",
+ "nose4",
+ "right_nose_2",
+ "right_nose_1",
+ "nose_middle",
+ "left_nose_1",
+ "left_nose_2",
+ "right_eye1",
+ "right_eye2",
+ "right_eye3",
+ "right_eye4",
+ "right_eye5",
+ "right_eye6",
+ "left_eye4",
+ "left_eye3",
+ "left_eye2",
+ "left_eye1",
+ "left_eye6",
+ "left_eye5",
+ "right_mouth_1",
+ "right_mouth_2",
+ "right_mouth_3",
+ "mouth_top",
+ "left_mouth_3",
+ "left_mouth_2",
+ "left_mouth_1",
+ "left_mouth_5",
+ "left_mouth_4",
+ "mouth_bottom",
+ "right_mouth_4",
+ "right_mouth_5",
+ "right_lip_1",
+ "right_lip_2",
+ "lip_top",
+ "left_lip_2",
+ "left_lip_1",
+ "left_lip_3",
+ "lip_bottom",
+ "right_lip_3",
+ "right_contour_1",
+ "right_contour_2",
+ "right_contour_3",
+ "right_contour_4",
+ "right_contour_5",
+ "right_contour_6",
+ "right_contour_7",
+ "right_contour_8",
+ "contour_middle",
+ "left_contour_8",
+ "left_contour_7",
+ "left_contour_6",
+ "left_contour_5",
+ "left_contour_4",
+ "left_contour_3",
+ "left_contour_2",
+ "left_contour_1"
+]
+
+import cv2
+import numpy as np
+import math
+import matplotlib
+
+def draw_bodypose(canvas, candidate):
+    """Draw an 18-keypoint body skeleton and return the resulting image.
+
+    ``candidate[i][0:2]`` is read as ``(x, y)``; x is multiplied by the canvas
+    width and y by the canvas height before drawing.
+
+    Limbs are filled into the ``canvas`` array that was passed in. A new array
+    (``canvas * 0.6`` cast to uint8) is then created, the joint dots are drawn
+    on that new array, and the new array is returned. Callers must use the
+    return value to get the dimmed limbs and the dots.
+    """
+    height, width, _ = canvas.shape
+    keypoints = np.array(candidate)
+
+    half_thickness = 4
+
+    # 1-based keypoint index pairs, one per limb. Only the first 17 are drawn.
+    limb_pairs = [
+        [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
+        [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
+        [1, 16], [16, 18], [3, 17], [6, 18],
+    ]
+
+    palette = [
+        [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
+        [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
+        [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85],
+    ]
+
+    for limb_id in range(17):
+        ends = (np.array(limb_pairs[limb_id]) - 1).astype(int)
+        cols = keypoints[ends, 0] * float(width)
+        rows = keypoints[ends, 1] * float(height)
+        # A limb is skipped when either endpoint's scaled y value is below eps.
+        if rows[0] < eps or rows[1] < eps:
+            continue
+        mid_row = np.mean(rows)
+        mid_col = np.mean(cols)
+        d_row = rows[0] - rows[1]
+        d_col = cols[0] - cols[1]
+        half_length = int((d_row ** 2 + d_col ** 2) ** 0.5 / 2)
+        tilt = int(math.degrees(math.atan2(d_row, d_col)))
+        outline = cv2.ellipse2Poly((int(mid_col), int(mid_row)), (half_length, half_thickness), tilt, 0, 360, 1)
+        cv2.fillConvexPoly(canvas, outline, palette[limb_id])
+
+    # From here on a copy is drawn on, not the caller's array.
+    dimmed = (canvas * 0.6).astype(np.uint8)
+
+    for joint_id in range(18):
+        px, py = keypoints[joint_id][0:2]
+        px = int(px * width)
+        py = int(py * height)
+        if px > eps and py > eps:
+            cv2.circle(dimmed, (int(px), int(py)), 4, palette[joint_id], thickness=-1)
+
+    return dimmed
+
+
+def draw_handpose(canvas, peaks):
+    """Draw one hand (21 keypoints) onto ``canvas`` in place and return it.
+
+    ``peaks[i]`` is unpacked as ``(x, y)``; x is multiplied by the canvas
+    width and y by the canvas height. A bone or a dot is drawn only when all
+    of its pixel coordinates are above ``eps``.
+    """
+    height, width, _ = canvas.shape
+
+    # Keypoint index pairs: four bones per finger, each finger starting at index 0.
+    bones = [
+        [0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10],
+        [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20],
+    ]
+    bone_count = float(len(bones))
+
+    points = np.array(peaks)
+
+    for bone_id, (start, end) in enumerate(bones):
+        sx, sy = points[start]
+        ex, ey = points[end]
+        sx = int(sx * width)
+        sy = int(sy * height)
+        ex = int(ex * width)
+        ey = int(ey * height)
+        if all(value > eps for value in (sx, sy, ex, ey)):
+            # Hue runs over the bone index; saturation and value stay at 1.
+            bone_color = matplotlib.colors.hsv_to_rgb([bone_id / bone_count, 1.0, 1.0]) * 255
+            cv2.line(canvas, (sx, sy), (ex, ey), bone_color, thickness=2)
+
+    for point in points:
+        px, py = point
+        px = int(px * width)
+        py = int(py * height)
+        if px > eps and py > eps:
+            cv2.circle(canvas, (px, py), 4, (0, 0, 255), thickness=-1)
+    return canvas
+
+
+def draw_facepose(canvas, lmks):
+    """Draw each face landmark as a filled white dot on ``canvas`` (in place) and return it.
+
+    Each landmark is unpacked as ``(x, y)``; x is multiplied by the canvas
+    width and y by the canvas height. Landmarks whose pixel x or y is not
+    above ``eps`` are skipped.
+    """
+    height, width, _ = canvas.shape
+    for landmark in np.array(lmks):
+        px, py = landmark
+        px = int(px * width)
+        py = int(py * height)
+        if px > eps and py > eps:
+            cv2.circle(canvas, (px, py), 3, (255, 255, 255), thickness=-1)
+    return canvas
+
+def draw_openpose_map(canvas, smplx_keypoints_2d, smplx_keypoints_mask):
+    """Render body, hands and face keypoints onto ``canvas``.
+
+    Args:
+        canvas: image array, modified in place and also returned.
+        smplx_keypoints_2d: one 2D keypoint per entry of ``SMPLX_NAMES``, in
+            that order (asserted). The draw helpers multiply x by the canvas
+            width and y by the canvas height.
+        smplx_keypoints_mask: one truthy/falsy flag per keypoint. A keypoint
+            that is flagged off and not listed in ``JOINT_NAMES`` is replaced
+            by zeros, which the draw helpers skip.
+
+    Returns:
+        ``canvas`` with the pose drawn on it.
+    """
+    assert len(SMPLX_NAMES) == len(smplx_keypoints_2d)
+    kp_dict = dict()
+    for k, p, b in zip(SMPLX_NAMES, smplx_keypoints_2d, smplx_keypoints_mask):
+        b = b or k in JOINT_NAMES
+        if not b:
+            # Out-of-place: `p *= 0` would also zero the caller's keypoint array.
+            p = p * 0
+        kp_dict[k] = p
+    body_points = [kp_dict[k] for k in OPENPOSE_BODY]
+    left_hand_points = [kp_dict[k] for k in OPENPOSE_LEFT_HAND]
+    right_hand_points = [kp_dict[k] for k in OPENPOSE_RIGHT_HAND]
+    face_points = [kp_dict[k] for k in OPENPOSE_FACE]
+    # draw_bodypose() puts the dimmed limbs and the joint dots on a NEW array and
+    # returns it; discarding that result loses them. Copy it back so callers that
+    # rely on in-place drawing and callers that use the return value both see it.
+    canvas[...] = draw_bodypose(canvas, body_points)
+    draw_handpose(canvas, left_hand_points)
+    draw_handpose(canvas, right_hand_points)
+    draw_facepose(canvas, face_points)
+    return canvas
\ No newline at end of file
diff --git a/core/lib/provider.py b/core/lib/provider.py
new file mode 100755
index 0000000..514881e
--- /dev/null
+++ b/core/lib/provider.py
@@ -0,0 +1,322 @@
+import os
+import cv2
+import glob
+import json
+import tqdm
+import random
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+from .camera_utils import *
+
+
+def get_view_direction(thetas, phis, overhead, front, phi_diff=0):
+    """Classify camera angles into six view-direction ids.
+
+    Args:
+        thetas: [B] polar angles in radians.
+        phis: [B] azimuth angles in radians. Not modified.
+        overhead: polar half-angle (radians) of the top/bottom caps.
+        front: azimuth width (radians) of the front and back sectors.
+        phi_diff: offset in DEGREES added to ``phis`` before classifying.
+
+    Returns:
+        LongTensor [B] with
+            0 front         phi in [0, front)
+            1 side (left)   phi in [front, pi)
+            2 back          phi in [pi, pi + front)
+            3 side (right)  phi in [pi + front, 2*pi)
+            4 top           theta <= overhead
+            5 bottom        theta >= pi - overhead
+        where 4 and 5 override the azimuth-based ids.
+    """
+    # Out-of-place add: the former `phis += ...` (and the wrap-around writes below)
+    # silently changed the caller's tensor.
+    phis = phis + phi_diff / 180 * np.pi
+    phis[phis >= 2 * np.pi] -= np.pi * 2
+    phis[phis < 0] += np.pi * 2
+    res = torch.zeros(thetas.shape[0], dtype=torch.long)
+    # first determine by phis
+    res[(phis < front)] = 0
+    res[(phis >= front) & (phis < np.pi)] = 1
+    res[(phis >= np.pi) & (phis < (np.pi + front))] = 2
+    res[(phis >= (np.pi + front))] = 3
+    # override by thetas
+    res[thetas <= overhead] = 4
+    res[thetas >= (np.pi - overhead)] = 5
+    return res
+
+
+
+def rand_poses(size,
+               device,
+               radius_range=[1, 1.5],
+               theta_range=[0, 120],
+               phi_range=[0, 360],
+               height_range=[0,0],
+               return_dirs=False,
+               angle_overhead=30,
+               angle_front=60,
+               jitter=False,
+               uniform_sphere_rate=0.5,
+               phi_diff=0,
+               center_offset=0.,
+               ):
+    ''' generate random poses from an orbit camera
+    Args:
+        size: batch size of generated poses.
+        device: where to allocate the output.
+        radius_range: [min, max] camera distance.
+        theta_range: [min, max] polar angle in DEGREES (converted to radians here).
+        phi_range: [min, max] azimuth angle in DEGREES (converted to radians here).
+        height_range: [min, max] vertical shift applied to both camera and look-at target.
+        return_dirs: when True, also return view-direction ids from get_view_direction().
+        angle_overhead, angle_front: sector sizes in DEGREES for get_view_direction().
+        jitter: add random noise to camera position, look-at target and up vector.
+        uniform_sphere_rate: probability of sampling the direction on the upper
+            (y >= 0) hemisphere instead of from theta_range / phi_range.
+        phi_diff: azimuth offset in degrees forwarded to get_view_direction().
+        center_offset: offset added to both the camera centers and the look-at target.
+    Return:
+        poses: [size, 4, 4] camera-to-world matrices
+        dirs: [size] view-direction ids, or None when return_dirs is False
+        radius: [size] sampled radii
+    '''
+
+    theta_range = np.deg2rad(theta_range)
+    phi_range = np.deg2rad(phi_range)
+    angle_overhead = np.deg2rad(angle_overhead)
+    angle_front = np.deg2rad(angle_front)
+
+    radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
+    if random.random() < uniform_sphere_rate:
+        unit_centers = F.normalize(
+            torch.stack([
+                (torch.rand(size, device=device) - 0.5) * 2.0,
+                torch.rand(size, device=device),
+                (torch.rand(size, device=device) - 0.5) * 2.0,
+            ],
+                        dim=-1),
+            p=2,
+            dim=1)
+        thetas = torch.acos(unit_centers[:, 1])
+        phis = torch.atan2(unit_centers[:, 0], unit_centers[:, 2])
+        phis[phis < 0] += 2 * np.pi
+        centers = unit_centers * radius.unsqueeze(-1)
+        centers = centers + centers.new_tensor(center_offset)
+        # This branch applies no height shift. `heights` must still exist because the
+        # look-at target below adds it; leaving it undefined raised NameError.
+        heights = torch.zeros(size, device=device)
+    else:
+        heights = torch.rand(size, device=device) * (height_range[1] - height_range[0]) + height_range[0]
+        thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
+        phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
+
+        centers = torch.stack([
+            radius * torch.sin(thetas) * torch.sin(phis),
+            radius * torch.cos(thetas) + heights,
+            radius * torch.sin(thetas) * torch.cos(phis),
+        ],
+                              dim=-1)  # [B, 3]
+        centers = centers + centers.new_tensor(center_offset)
+
+    targets = torch.zeros_like(centers) + centers.new_tensor(center_offset)
+    targets[:, 1] += heights
+
+    # jitters
+    if jitter:
+        centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
+        targets = targets + torch.randn_like(centers) * 0.2
+
+    # lookat
+    forward_vector = safe_normalize(centers - targets)
+    up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
+    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
+
+    if jitter:
+        up_noise = torch.randn_like(up_vector) * 0.02
+    else:
+        up_noise = 0
+
+    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)
+
+    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
+    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
+    poses[:, :3, 3] = centers
+
+    if return_dirs:
+        dirs = get_view_direction(thetas, phis, angle_overhead, angle_front, phi_diff=phi_diff)
+    else:
+        dirs = None
+
+    return poses, dirs, radius
+
+
+def circle_poses(device, radius=1.25, theta=60, phi=0, return_dirs=False, angle_overhead=30, angle_front=60, phi_diff=0, height=0,
+                 center_offset=0.,):
+    """Build a single look-at camera pose on an orbit around ``center_offset``.
+
+    ``theta``, ``phi``, ``angle_overhead`` and ``angle_front`` are given in
+    degrees and converted to radians here. ``height`` shifts both the camera
+    and its look-at target along y.
+
+    Returns ``(poses, dirs, radius)``: ``poses`` is a [1, 4, 4] matrix, ``dirs``
+    is the get_view_direction() result (or None when ``return_dirs`` is
+    False), and ``radius`` is passed through unchanged.
+    """
+    theta, phi = np.deg2rad(theta), np.deg2rad(phi)
+    angle_overhead, angle_front = np.deg2rad(angle_overhead), np.deg2rad(angle_front)
+
+    thetas = torch.FloatTensor([theta]).to(device)
+    phis = torch.FloatTensor([phi]).to(device)
+
+    # Spherical -> Cartesian with y up.
+    planar = radius * torch.sin(thetas)
+    orbit = torch.stack(
+        [planar * torch.sin(phis), radius * torch.cos(thetas) + height, planar * torch.cos(phis)],
+        dim=-1)  # [B, 3]
+    offset = orbit.new_tensor(center_offset)
+    centers = orbit + offset
+
+    # lookat: the target sits at the offset, raised by the same height as the camera
+    targets = torch.zeros_like(centers) + offset
+    targets[:, 1] += height
+
+    forward_vector = safe_normalize(centers - targets)
+    world_up = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0)
+    right_vector = safe_normalize(torch.cross(forward_vector, world_up, dim=-1))
+    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))
+
+    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0)
+    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
+    poses[:, :3, 3] = centers
+
+    dirs = get_view_direction(thetas, phis, angle_overhead, angle_front, phi_diff=phi_diff) if return_dirs else None
+
+    return poses, dirs, radius
+
+
+# Produces camera data (pose, MVP matrix, intrinsics, view-direction id) one view at a time:
+# random orbit poses when training, a fixed circle of poses otherwise.
+# NOTE(review): indentation is collapsed in this view of the file, so the nesting of some
+# statements in collate() cannot be confirmed from here.
+class ViewDataset:
+
+ def __init__(self, cfg, device, type='train', H=256, W=256, size=100, render_head=False, render_canpose=False):
+ # cfg: config object; this class reads cfg.model.min_near, cfg.train.*, cfg.data.can_pose_folder
+ # and cfg.guidance.use_view_prompt.
+ # render_canpose is accepted but not used anywhere in this class.
+ super().__init__()
+
+ self.cfg = cfg
+ self.device = device
+ self.type = type # train, val, test
+
+ self.H = H
+ self.W = W
+ self.size = size
+ self.num_frames = size
+ # With render_head the index range is doubled; collate() renders indices >= num_frames
+ # as face views.
+ if render_head:
+ self.size = self.num_frames * 2
+
+ self.training = self.type in ['train', 'all']
+
+ # NOTE(review): cx is derived from H and cy from W; this only matters when H != W - confirm intent.
+ self.cx = self.H / 2
+ self.cy = self.W / 2
+
+ self.near = self.cfg.model.min_near
+ self.far = 1000 # infinite
+
+ self.aspect = self.W / self.H
+ self.global_step = 0
+
+ def get_phi_range(self):
+ # Azimuth range (degrees, as consumed by rand_poses) for full-body training views.
+ return self.cfg.train.phi_range
+
+ def update_global_step(self, global_step):
+ # Stored only; nothing in this class reads self.global_step.
+ self.global_step = global_step
+
+ def collate(self, index):
+ # Build the data dict for one batch of indices (used as the DataLoader collate_fn).
+
+ B = len(index) # always 1
+ is_face = False
+ can_pose = False
+ if self.training:
+ if self.cfg.data.can_pose_folder is not None:
+ can_pose = random.random() < self.cfg.train.can_pose_sample_ratio
+ # random pose on the fly
+ # With probability face_sample_ratio, sample a close-up pose centred on the head position.
+ if random.random() < self.cfg.train.face_sample_ratio:
+ poses, dirs, radius = rand_poses(
+ B,
+ self.device,
+ radius_range=self.cfg.train.face_radius_range,
+ return_dirs=self.cfg.guidance.use_view_prompt,
+ angle_overhead=self.cfg.train.angle_overhead,
+ angle_front=self.cfg.train.angle_front,
+ jitter=False,
+ uniform_sphere_rate=0.,
+ phi_diff=self.cfg.train.face_phi_diff,
+ theta_range=self.cfg.train.face_theta_range,
+ phi_range=self.cfg.train.face_phi_range,
+ height_range=self.cfg.train.face_height_range,
+ center_offset=np.array(self.cfg.train.head_position if not can_pose else self.cfg.train.canpose_head_position)
+ )
+ is_face = True
+ else:
+ poses, dirs, radius = rand_poses(
+ B,
+ self.device,
+ radius_range=self.cfg.train.radius_range,
+ return_dirs=self.cfg.guidance.use_view_prompt,
+ angle_overhead=self.cfg.train.angle_overhead,
+ angle_front=self.cfg.train.angle_front,
+ jitter=self.cfg.train.jitter_pose,
+ uniform_sphere_rate=0.,
+ phi_diff=self.cfg.train.phi_diff,
+ theta_range=self.cfg.train.theta_range,
+ phi_range=self.get_phi_range(),
+ height_range=self.cfg.train.height_range,
+ )
+ # random focal
+ fov = random.random() * (self.cfg.train.fovy_range[1] - self.cfg.train.fovy_range[0]) + self.cfg.train.fovy_range[0]
+ else:
+ # circle pose
+ # The index is mapped to an azimuth in degrees; indices >= num_frames wrap around and
+ # are rendered as face views.
+ phi = ((index[0] / self.num_frames) * 360)%360
+ if index[0] < self.num_frames:
+ poses, dirs, radius = circle_poses(
+ self.device,
+ radius=self.cfg.train.radius_range[1] * 0.9,
+ theta=90,
+ phi=phi,
+ return_dirs=self.cfg.guidance.use_view_prompt,
+ angle_overhead=self.cfg.train.angle_overhead,
+ angle_front=self.cfg.train.angle_front,
+ phi_diff=self.cfg.train.phi_diff
+ )
+ else:
+ is_face = True
+ poses, dirs, radius = circle_poses(
+ self.device,
+ radius=self.cfg.train.face_radius_range[0],
+ height=self.cfg.train.face_height_range[0],
+ theta=90,
+ phi=phi,
+ return_dirs=self.cfg.guidance.use_view_prompt,
+ angle_overhead=self.cfg.train.angle_overhead,
+ angle_front=self.cfg.train.angle_front,
+ phi_diff=self.cfg.train.phi_diff,
+ center_offset=np.array(self.cfg.train.head_position)
+ )
+
+ # fixed focal
+ fov = (self.cfg.train.fovy_range[1] + self.cfg.train.fovy_range[0]) / 2
+
+ # Focal length in pixels from the vertical field of view (fov is in degrees).
+ focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2))
+ intrinsics = np.array([focal, focal, self.cx, self.cy])
+
+ projection = torch.tensor([
+ [2 * focal / self.W, 0, 0, 0],
+ [0, -2 * focal / self.H, 0, 0],
+ [0, 0, -(self.far + self.near)/(self.far - self.near), -(2 * self.far * self.near)/(self.far - self.near)],
+ [0, 0, -1, 0]
+ ], dtype=torch.float32, device=self.device).unsqueeze(0) # yapf: disabl
+ # mvp = projection @ world-to-camera (the inverse of the camera-to-world pose).
+ mvp = projection @ torch.inverse(poses.cpu()).to(self.device)
+ # NOTE(review): outside training mvp is recomputed. For face / canonical-pose views it is the
+ # same product as above (redundant). Otherwise it is built without `projection`, with the
+ # [0, 2, 3] entry zeroed and the y and z rows sign-flipped - presumably an orthographic-style
+ # view; the exact nesting of the statements below is not visible here - confirm.
+ if not self.training:
+ if is_face or can_pose:
+ mvp = projection @ torch.inverse(poses.cpu()).to(self.device)
+ else:
+ mvp = torch.inverse(poses.cpu()).to(self.device)
+ mvp[0, 2, 3] = 0.
+ TO_WORLD = np.eye(
+ 4,
+ dtype=np.float32,
+ )
+ TO_WORLD[2,2] = -1
+ TO_WORLD[1,1] = -1
+ TO_WORLD = mvp.new_tensor(TO_WORLD)
+ mvp = TO_WORLD @ mvp
+
+ data = {
+ 'H': self.H,
+ 'W': self.W,
+ 'mvp': mvp[0], # [4, 4]
+ 'poses': poses, # [1, 4, 4]
+ 'intrinsics': intrinsics,
+ 'dir': dirs,
+ 'near_far': [self.near, self.far],
+ 'is_face': is_face,
+ 'radius': radius,
+ 'can_pose': can_pose
+ }
+
+ return data
+
+ def dataloader(self):
+ # The "dataset" is just the index range; collate() turns each index into camera data.
+ # Shuffled only in training mode.
+ loader = DataLoader(
+ list(range(self.size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
+ return loader
\ No newline at end of file
diff --git a/core/lib/renderer.py b/core/lib/renderer.py
new file mode 100755
index 0000000..c582a2a
--- /dev/null
+++ b/core/lib/renderer.py
@@ -0,0 +1,607 @@
+import os
+import math
+import cv2
+import trimesh
+import numpy as np
+import random
+from pathlib import Path
+
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+import nvdiffrast.torch as dr
+# import kaolin as kal
+from .network_utils import get_encoder
+
+from .obj import Mesh, safe_normalize
+from .marching_tets import DMTet
+from .tet_utils import build_tet_grid
+from .dmtet_network import DMTetMesh
+from .color_network import ColorNetwork
+from .uv_utils import texture_padding
+from PIL import Image
+def scale_img_nhwc(x, size, mag='bilinear', min='bilinear'):
+    """Resize an NHWC image batch to ``size`` = (height, width).
+
+    Both spatial dimensions must shrink (or stay equal) together, or grow
+    together; mixing the two fails the assertion. Strict shrinking in both
+    dimensions uses interpolation mode ``min``; every other case uses ``mag``,
+    with ``align_corners=True`` for the bilinear and bicubic modes.
+    """
+    src_h, src_w = x.shape[1], x.shape[2]
+    dst_h, dst_w = size[0], size[1]
+    not_growing = src_h >= dst_h and src_w >= dst_w
+    growing = src_h < dst_h and src_w < dst_w
+    assert not_growing or growing, "Trying to magnify image in one dimension and minify in the other"
+
+    channels_first = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
+    if src_h > dst_h and src_w > dst_w:  # Minification, previous size was bigger
+        resized = torch.nn.functional.interpolate(channels_first, size, mode=min)
+    elif mag == 'bilinear' or mag == 'bicubic':  # Magnification with corner alignment
+        resized = torch.nn.functional.interpolate(channels_first, size, mode=mag, align_corners=True)
+    else:  # Magnification, modes that take no align_corners
+        resized = torch.nn.functional.interpolate(channels_first, size, mode=mag)
+    return resized.permute(0, 2, 3, 1).contiguous()  # NCHW -> NHWC
+
+
+def scale_img_hwc(x, size, mag='bilinear', min='bilinear'):
+    """Resize a single HWC image by adding a batch axis, calling scale_img_nhwc, and removing it again."""
+    batched = x.unsqueeze(0)
+    return scale_img_nhwc(batched, size, mag, min).squeeze(0)
+
+
+def scale_img_nhw(x, size, mag='bilinear', min='bilinear'):
+    """Resize a batch of NHW maps by adding a trailing channel axis, calling scale_img_nhwc, and removing it again."""
+    with_channel = x.unsqueeze(-1)
+    return scale_img_nhwc(with_channel, size, mag, min).squeeze(-1)
+
+
+def scale_img_hw(x, size, mag='bilinear', min='bilinear'):
+    """Resize a single HW map by adding batch and channel axes, calling scale_img_nhwc, and removing both again."""
+    expanded = x.unsqueeze(0).unsqueeze(-1)
+    return scale_img_nhwc(expanded, size, mag, min).squeeze(-1).squeeze(0)
+
+
+def trunc_rev_sigmoid(x, eps=1e-6):
+    """Inverse sigmoid (logit) of ``x`` after clamping it to [eps, 1 - eps] so the result stays finite."""
+    clamped = torch.clamp(x, eps, 1 - eps)
+    odds = clamped / (1 - clamped)
+    return torch.log(odds)
+
+
+class MLP(nn.Module):
+    """Stack of ``num_layers`` linear layers with ReLU between them (none after the last).
+
+    The first layer maps ``dim_in`` to ``dim_hidden`` and the last maps to
+    ``dim_out``; with a single layer it maps ``dim_in`` straight to ``dim_out``.
+    """
+
+    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
+        super().__init__()
+        self.dim_in = dim_in
+        self.dim_out = dim_out
+        self.dim_hidden = dim_hidden
+        self.num_layers = num_layers
+
+        final = num_layers - 1
+        layers = [
+            nn.Linear(
+                self.dim_in if depth == 0 else self.dim_hidden,
+                self.dim_out if depth == final else self.dim_hidden,
+                bias=bias)
+            for depth in range(num_layers)
+        ]
+        self.net = nn.ModuleList(layers)
+
+    def forward(self, x):
+        final = self.num_layers - 1
+        for depth, layer in enumerate(self.net):
+            x = layer(x)
+            if depth != final:
+                x = F.relu(x, inplace=True)
+        return x
+
+
+class Renderer(nn.Module):
+
+ # Builds the rasterizer context, the optional 3D colour network, the geometry (either a
+ # DMTet network initialised from the loaded mesh or the loaded mesh itself), the optional
+ # canonical-pose meshes, and the background MLP(s).
+ # NOTE(review): indentation is collapsed in this view, so the exact nesting of some
+ # statements below cannot be confirmed from here.
+ def __init__(
+ self,
+ cfg,
+ num_layers_bg=2,
+ hidden_dim_bg=16,
+ ):
+
+ super().__init__()
+
+ self.cfg = cfg
+ self.min_near = cfg.model.min_near
+ # self.v_offsets = 0
+ # self.vn_offsets = 0
+
+ # Rasterizer backend: the CUDA context unless cfg.use_gl is set.
+ if not self.cfg.use_gl:
+ self.glctx = dr.RasterizeCudaContext() # support at most 2048 resolution.
+ else:
+ print('building gl context')
+ # # try:
+ self.glctx = dr.RasterizeGLContext() # will crash if using GUI...
+ # except:
+ # print('Failed to initialize GLContext, use CudaContext instead...')
+ # self.glctx = dr.RasterizeCudaContext() # support at most 2048 resolution.
+ # load the template mesh, will calculate normal and texture if not provided.
+ if self.cfg.model.use_color_network:
+ self.texture3d = ColorNetwork(cfg=cfg, num_layers=cfg.model.color_num_layers, hidden_dim=cfg.model.color_hidden_dim,
+ hash_max_res=cfg.model.color_hash_max_res, hash_num_levels=cfg.model.color_hash_num_levels)
+ else:
+ self.texture3d = None
+
+ # TODO: textrue 2D
+
+ # Geometry: with use_dmtet_network the loaded mesh seeds a tetrahedral grid and a DMTetMesh
+ # network; otherwise the loaded mesh is used directly and dmtet_network stays None.
+ if cfg.model.use_dmtet_network:
+ self.mesh = Mesh.load_obj(self.cfg.data.last_model, ref_path=self.cfg.data.last_ref_model, init_empty_tex=self.cfg.train.init_empty_tex, albedo_res=self.cfg.model.albedo_res, keypoints_path=self.cfg.data.keypoints_path, init_uv=False)
+ if self.mesh.keypoints is not None:
+ self.keypoints = self.mesh.keypoints
+ else:
+ self.keypoints = None
+ self.marching_tets = None
+ tet_v, tet_ind = build_tet_grid(self.mesh, cfg)
+ self.dmtet_network = DMTetMesh(vertices=torch.tensor(tet_v, dtype=torch.float), indices=torch.tensor(tet_ind, dtype=torch.long), grid_scale=self.cfg.model.tet_grid_scale, use_explicit=cfg.model.use_explicit_tet, geo_network=cfg.model.dmtet_network,
+ hash_max_res=cfg.model.geo_hash_max_res, hash_num_levels=cfg.model.geo_hash_num_levels, num_subdiv=cfg.model.tet_num_subdiv)
+ if self.cfg.train.init_mesh and not self.cfg.test.test:
+ self.dmtet_network.init_mesh(self.mesh.v, self.mesh.f, self.cfg.train.init_mesh_padding)
+ else:
+ self.mesh = Mesh.load_obj(self.cfg.data.last_model, ref_path=self.cfg.data.last_ref_model, init_empty_tex=self.cfg.train.init_empty_tex, albedo_res=self.cfg.model.albedo_res, use_vertex_tex=self.cfg.model.use_vertex_tex, keypoints_path=self.cfg.data.keypoints_path, init_uv=self.cfg.test.save_uv)
+ if self.mesh.keypoints is not None:
+ self.keypoints = self.mesh.keypoints
+ else:
+ self.keypoints = None
+ self.marching_tets = None
+ self.dmtet_network = None
+ self.mesh.v = self.mesh.v * self.cfg.model.mesh_scale
+ # NOTE(review): init_texture_3d() calls self.texture3d.cuda(), so this presumably requires
+ # use_color_network to be enabled - confirm.
+ if cfg.train.init_texture_3d:
+ self.init_texture_3d()
+
+ if cfg.model.use_vertex_tex:
+ self.vertex_albedo = nn.Parameter(self.mesh.v_color)
+ # extract trainable parameters
+ if self.dmtet_network is None and not cfg.train.lock_geo:
+ self.sdf = nn.Parameter(torch.zeros_like(self.mesh.v[..., 0]))
+ self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))
+ self.vn_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))
+
+ # Canonical-pose meshes: can_pose_folder is either a single .obj path or a folder of .obj files.
+ if self.cfg.data.can_pose_folder:
+ import glob
+ if '.obj' in self.cfg.data.can_pose_folder:
+ can_pose_objs = [self.cfg.data.can_pose_folder]
+ else:
+ can_pose_objs = glob.glob(self.cfg.data.can_pose_folder + '/*.obj')
+ self.can_pose_vertices = []
+ self.can_pose_faces = []
+ # NOTE(review): can_pose_resize_inv is created but never appended to in this method.
+ self.can_pose_resize_inv = []
+ for pose_obj in can_pose_objs:
+ tri_mesh = trimesh.load(pose_obj)
+ mesh = Mesh(torch.tensor(tri_mesh.vertices, dtype=torch.float32).cuda(), torch.tensor(tri_mesh.faces, dtype=torch.int32).cuda())
+ mesh.auto_size()
+ self.can_pose_vertices.append(mesh.v)
+ self.can_pose_faces.append(mesh.f)
+
+ # background network
+ # Either two separate background MLPs (normal / textureless) or a single shared one.
+ self.encoder_bg, self.in_dim_bg = get_encoder('frequency_torch', input_dim=3, multires=4)
+ if self.cfg.model.different_bg:
+ self.normal_bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
+ self.textureless_bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
+ else:
+ self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
+
+
+ def init_texture_3d(self):
+     """Initialise the 3D colour network from the mesh's per-vertex colours.
+
+     If ``<cfg.workspace>/init_tex.pth`` exists, its state dict is loaded.
+     Otherwise, when the mesh has vertex colours, the network is fitted to
+     them with an MSE loss for a fixed number of Adam steps and the result
+     is saved to that path. With neither a checkpoint nor vertex colours
+     the network is left as constructed.
+     """
+     self.texture3d = self.texture3d.cuda()
+     optimizer = torch.optim.Adam(self.texture3d.parameters(), lr=0.01, betas=(0.9, 0.99),
+                                  eps=1e-15)
+     os.makedirs(self.cfg.workspace, exist_ok=True)
+     ckpt_path = os.path.join(self.cfg.workspace, 'init_tex.pth')
+     if os.path.exists(ckpt_path):
+         state_dict = torch.load(ckpt_path)
+         self.texture3d.load_state_dict(state_dict)
+     elif self.mesh.v_color is not None:
+         batch_size = 300000
+         num_epochs = 200
+         v_color = self.mesh.v_color
+         v_pos = self.mesh.v
+         num_pts = v_pos.shape[0]
+         print('start init texture 3d')
+         for i in range(num_epochs):
+             optimizer.zero_grad()
+             # Random vertex subset per step (all vertices when there are fewer than batch_size).
+             indice = random.sample(range(num_pts), min(batch_size, num_pts))
+             batch_pos = v_pos[indice].cuda()
+             batch_color = v_color[indice].cuda()
+             pred_color = self.texture3d(batch_pos)
+             loss_rgb = nn.functional.mse_loss(pred_color, batch_color)
+             loss = loss_rgb
+             loss.backward()
+             optimizer.step()
+             # Only the RGB loss exists; the former message also formatted an undefined
+             # `loss_norm`, which raised NameError on the first iteration.
+             print('Iter {}: loss_rgb: {}'.format(i, loss_rgb.data))
+         torch.save(self.texture3d.state_dict(), ckpt_path)
+
+
+ # optimizer utils
+ def get_params(self, lr):
+     """Return the optimizer parameter groups for this renderer.
+
+     Args:
+         lr: base learning rate. Background nets, the colour network and the
+             vertex albedo use it directly; the DMTet network uses
+             ``lr * cfg.train.dmtet_lr``; explicit vertex / normal offsets
+             use a fixed 1e-4.
+
+     Returns:
+         list of ``{'params': ..., 'lr': ...}`` dicts.
+     """
+     # yapf: disable
+
+     # Start from an empty list so both background configurations can append to it;
+     # the `different_bg` branch previously did `params += ...` before `params` existed.
+     params = []
+
+     if self.cfg.model.different_bg:
+         params += [
+             {'params': self.textureless_bg_net.parameters(), 'lr': lr},
+             {'params': self.normal_bg_net.parameters(), 'lr': lr},
+         ]
+     else:
+         params += [
+             {'params': self.bg_net.parameters(), 'lr': lr},
+         ]
+
+     if self.texture3d is not None:
+         params += [
+             {'params': self.texture3d.parameters(), 'lr': lr},
+         ]
+
+     if not self.cfg.train.lock_geo:
+         if self.dmtet_network is not None:
+             params.extend([
+                 {'params': self.dmtet_network.parameters(), 'lr': lr*self.cfg.train.dmtet_lr}
+             ])
+         else:
+             params.extend([
+                 {'params': self.v_offsets, 'lr': 0.0001},
+                 {'params': self.vn_offsets, 'lr': 0.0001},
+             ])
+     # yapf: enable
+     if self.cfg.model.use_vertex_tex:
+         vertex_tex_lr = lr * 1
+         params.extend([
+             {'params': self.vertex_albedo, 'lr': vertex_tex_lr}
+         ])
+         print('vertex_tex_lr:', vertex_tex_lr)
+     return params
+
+
+ # Writes the current mesh to '<save_path>/<name>.obj', and additionally to
+ # '<save_path>/<name>_da.obj' with vertices taken from cfg.data.da_pose_mesh when that is set.
+ # Replaces / mutates self.mesh while doing so.
+ # NOTE(review): indentation is collapsed in this view, so the exact nesting of the branches
+ # below cannot be confirmed from here.
+ @torch.no_grad()
+ def export_mesh(self, save_path, name='mesh', export_uv=False):
+ # Read the resize matrix before self.mesh is possibly replaced below.
+ self.resize_matrix_inv = self.mesh.resize_matrix_inv
+ if self.dmtet_network is not None:
+ num_subdiv = self.get_num_subdiv()
+ # The inner no_grad is redundant with the decorator on this method.
+ with torch.no_grad():
+ verts, faces, loss = self.dmtet_network.get_mesh(return_loss=False, num_subdiv=num_subdiv)
+ self.mesh = Mesh(v=verts, f=faces.int(), device='cuda', split=True)
+ # Placeholder all-ones 2048x2048 albedo.
+ self.mesh.albedo = torch.ones((2048, 2048, 3), dtype=torch.float).cuda()
+ if export_uv:
+ self.mesh.auto_uv()
+ self.mesh.auto_normal()
+ elif hasattr(self, 'v_offsets') and hasattr(self, 'vn_offsets'):
+ self.mesh.v = (self.mesh.v + self.v_offsets).detach()
+ self.mesh.vn = (self.mesh.vn + self.vn_offsets).detach() # TODO: may not be unit ?
+ else:
+ # No-op self-assignments: the mesh is exported unchanged.
+ self.mesh.v = self.mesh.v
+ self.mesh.vn = self.mesh.vn
+ # Pick the colour source for the exported texture.
+ # NOTE(review): self.raw_albedo is not defined in the part of the class visible here - confirm
+ # it exists when use_texture_2d is set.
+ if export_uv:
+ if self.cfg.model.use_vertex_tex:
+ self.mesh.v_color = self.vertex_albedo.detach().clamp(0, 1)
+ elif self.cfg.model.use_texture_2d:
+ self.mesh.albedo = torch.sigmoid(self.raw_albedo.detach())
+ else:
+ self.mesh.albedo = self.get_albedo_from_texture3d()
+ # Append a homogeneous 1 to each vertex and apply the transposed inverse resize matrix.
+ # NOTE(review): the result has as many columns as resize_matrix_inv has rows - presumably the
+ # matrix is 3x4 so vertices stay 3D; confirm.
+ verts = torch.cat([self.mesh.v, torch.ones_like(self.mesh.v[:, :1])], dim=1) @ self.resize_matrix_inv.T
+ self.mesh.v = verts
+ self.mesh.write(os.path.join(save_path, '{}.obj'.format(name)))
+ if self.cfg.data.da_pose_mesh:
+ import trimesh
+ verts = self.mesh.v.new_tensor(trimesh.load(self.cfg.data.da_pose_mesh).vertices)
+ # The replacement vertices must match the exported mesh one-to-one.
+ assert verts.shape[0] == self.mesh.v.shape[0], f"pose mesh verts: {self.mesh.v.shape[0]}, da pose mesh verts: {verts.shape[0]}"
+ self.mesh.v = verts
+ self.mesh.write(os.path.join(save_path, '{}_da.obj'.format(name)))
+
+
+ @torch.no_grad()
+ def get_albedo_from_texture3d(self):
+     """Bake the 3D colour network into a UV texture shaped like ``self.mesh.albedo``.
+
+     The mesh is rasterized in UV space, a 3D position is interpolated for
+     every covered texel (from the mesh itself, or from the first
+     canonical-pose mesh when ``cfg.model.use_can_pose_space`` is set), the
+     colour network is evaluated at those positions in chunks, and the
+     result goes through ``texture_padding`` together with the coverage mask.
+     """
+     tex_h, tex_w = self.mesh.albedo.shape[:2]
+
+     # UVs mapped to [-1, 1] and extended to (u, v, 0, 1) clip-space positions.
+     uv_ndc = self.mesh.vt *2.0 - 1.0
+     zeros = torch.zeros_like(uv_ndc[..., :1])
+     ones = torch.ones_like(uv_ndc[..., :1])
+     uv_clip = torch.cat((uv_ndc, zeros, ones), dim=-1)
+     print(uv_clip.shape, self.mesh.ft.shape, tex_h, tex_w)
+     rast, _ = dr.rasterize(self.glctx, uv_clip.unsqueeze(0), self.mesh.ft, (tex_h, tex_w))  # [1, h, w, 4]
+
+     if self.cfg.model.use_can_pose_space:
+         space_v, space_f = self.can_pose_vertices[0], self.can_pose_faces[0]
+     else:
+         space_v, space_f = self.mesh.v, self.mesh.f
+
+     positions, _ = dr.interpolate(space_v.unsqueeze(0), rast, space_f)  # [1, h, w, 3]
+     coverage, _ = dr.interpolate(torch.ones_like(self.mesh.v[:, :1]).unsqueeze(0), rast, self.mesh.f)  # [1, h, w, 1]
+     covered = (coverage > 0).view(-1)
+     #Image.fromarray((mask.reshape(h, w).cpu().numpy()*255).astype(np.uint8)).save('uv_map_mask.png')
+
+     texels = torch.zeros(tex_h * tex_w, 3, device='cuda', dtype=torch.float32)
+     query_pts = positions.view(-1, 3)[covered]
+
+     # Evaluate the colour network in chunks.
+     chunk = 300000
+     colors = [self.texture3d(query_pts[start:start + chunk]) for start in range(0, query_pts.shape[0], chunk)]
+     texels[covered] = torch.cat(colors, dim=0)
+
+     texel_u8 = (texels.reshape(tex_h, tex_w, 3).cpu().numpy()*255).astype(np.uint8)
+     mask_u8 = (covered.reshape(tex_h, tex_w).cpu().numpy()*255).astype(np.uint8)
+     padded = self.mesh.albedo.new_tensor(texture_padding(texel_u8, mask_u8)) / 255
+
+     return padded.reshape(self.mesh.albedo.shape)
+
+
+ def get_mesh(self, return_loss=True, detach_geo=False, global_step=1e7):
+ # Build the mesh used for rendering and return it together with an optional
+ # geometry loss as the tuple (mesh, loss).
+ #   return_loss: forwarded to self.dmtet_network.get_mesh.
+ #   detach_geo:  when set, vertices and faces from the DMTet network are detached.
+ #   global_step: forwarded to get_num_subdiv to pick the subdivision count.
+ # With neither marching_tets nor a DMTet network, the stored mesh is wrapped as-is
+ # and no loss is returned.
+ if self.marching_tets is None and self.dmtet_network is None:
+ return Mesh(v=self.mesh.v, base=self.mesh, device='cuda'), None
+ loss = None
+ if self.cfg.model.use_dmtet_network:
+ num_subdiv = self.get_num_subdiv(global_step=global_step)
+ verts, faces, loss = self.dmtet_network.get_mesh(return_loss=return_loss, num_subdiv=num_subdiv)
+ if detach_geo:
+ verts = verts.detach()
+ faces = faces.detach()
+ loss = None
+ mesh = Mesh(v=verts, f=faces.int(), device='cuda')
+ else:
+ # Explicit tetrahedral path: optionally displace the grid vertices by a tanh-bounded
+ # offset scaled by 2 / (2 * mesh_resolution), then extract the surface with marching_tets.
+ # NOTE(review): this branch assumes self.marching_tets is set whenever
+ # cfg.model.use_dmtet_network is false — confirm against the constructor.
+ v_deformed = self.mesh.v
+ if hasattr(self, 'v_offsets'):
+ v_deformed = v_deformed + 2 / (self.cfg.model.mesh_resolution * 2) * torch.tanh(self.v_offsets)
+ verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, self.sdf, self.mesh.f)
+ mesh = Mesh(v=verts, f=faces.int(), vt=uvs, ft=uv_idx.int())
+ mesh.auto_normal()
+ return mesh, loss
+
+    def get_color_from_vertex_texture(self, rast, rast_db, f, light_d, ambient_ratio, shading) -> Tensor:
+        """Shade the rasterized pixels from the per-vertex albedo.
+
+        ``shading`` selects the output: 'albedo' returns the raw interpolated
+        color, 'textureless' a gray Lambertian term, 'normal' the normal map
+        remapped by (n + 1) / 2, and anything else albedo times the
+        Lambertian term.
+        """
+        vertex_rgb = self.vertex_albedo.unsqueeze(0).contiguous()
+        albedo, _ = dr.interpolate(vertex_rgb, rast, f, rast_db=rast_db)
+        albedo = albedo.clamp(0., 1.)
+        # Pixels whose 4th rasterizer channel is not positive are zeroed (background).
+        covered = rast[..., 3:] > 0
+        albedo = torch.where(covered, albedo, torch.tensor(0.).to(albedo.device))
+
+        if shading == 'albedo':
+            return albedo
+
+        # Vertex positions may change, so normals are read from the current mesh
+        # (plus the learned offsets when present) on every call.
+        vertex_normals = self.mesh.vn
+        if hasattr(self, 'vn_offsets'):
+            vertex_normals = vertex_normals + self.vn_offsets
+        normal, _ = dr.interpolate(vertex_normals.unsqueeze(0).contiguous(), rast, self.mesh.f)
+        normal = safe_normalize(normal)
+
+        diffuse = (normal @ light_d).float().clamp(min=0)
+        lambertian = ambient_ratio + (1 - ambient_ratio) * diffuse
+
+        if shading == 'textureless':
+            return lambertian.unsqueeze(-1).repeat(1, 1, 1, 3)
+        if shading == 'normal':
+            return (normal + 1) / 2
+        # 'lambertian'
+        return albedo * lambertian.unsqueeze(-1)
+
+ def get_color_from_mesh(self, mesh, rast, light_d, ambient_ratio, shading, poses=None):
+ # Geometry-only shading: produces either a gray Lambertian image ('textureless')
+ # or a normal map ('normal') from the vertex normals of `mesh`.
+ # NOTE(review): `color` is assigned only for shading == 'textureless' or 'normal';
+ # any other value reaches `return color` with the name unbound. forward() only
+ # calls this method for those two modes.
+ vn = mesh.vn
+ normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, mesh.f)
+ normal = safe_normalize(normal)
+ #print('vn.grad {} normal.grad {}'.format(vn.requires_grad, normal is not None))
+ # Ambient term plus the clamped dot product between normal and light direction.
+ lambertian = ambient_ratio + (1 - ambient_ratio) * (normal @ light_d).float().clamp(min=0)
+
+ if shading == 'textureless':
+ color = lambertian.unsqueeze(-1).repeat(1, 1, 1, 3)
+ elif shading == 'normal':
+ if self.cfg.train.render_relative_normal and poses is not None:
+ # Re-express the normals using the inverse of `poses`; the inverse is computed on
+ # the CPU and moved back to the device of `normal`.
+ # NOTE(review): the normals are padded with 1.0 as the 4th component, so the
+ # translation part of the inverse pose is added as well — confirm that this is
+ # intended for direction vectors (a 0 would apply rotation only).
+ normal_shape_old = normal.shape
+ B = poses.shape[0]
+ normal = torch.matmul(F.pad(normal, pad=(0, 1), mode='constant', value=1.0).reshape(B, -1, 4),
+ torch.transpose(torch.inverse(poses.cpu()).to(normal.device), 1, 2).reshape(B, 4, 4)).float()
+ normal = normal[..., :3].reshape(normal_shape_old)
+ # Flips the sign of the third (z) component.
+ normal = normal * normal.new_tensor([1, 1, -1])
+ # Remap from [-1, 1] to [0, 1] for display.
+ color = (normal + 1) / 2
+ return color
+
+    def get_color_from_2d_texture(self, rast, rast_db, mesh, rays_o, light_d, ambient_ratio, shading) -> Tensor:
+        """Shade the rasterized pixels from the learnable 2D UV texture.
+
+        Texture coordinates are interpolated from ``self.mesh.vt`` /
+        ``self.mesh.ft`` and ``self.raw_albedo`` is sampled with mip-mapped
+        trilinear filtering; a sigmoid turns the raw texels into colors.
+
+        Args:
+            rast, rast_db: outputs of ``dr.rasterize`` for the current view.
+            mesh: mesh whose ``vn`` / ``f`` provide the shading normals.
+            rays_o: unused here; kept for interface compatibility.
+            light_d: light direction used for the Lambertian term.
+            ambient_ratio: weight of the ambient term in [0, 1].
+            shading: 'albedo', 'textureless', 'normal', or anything else for
+                Lambertian-shaded albedo.
+
+        Returns:
+            The shaded color image.
+        """
+        texc, texc_db = dr.interpolate(
+            self.mesh.vt.unsqueeze(0).contiguous(), rast, self.mesh.ft, rast_db=rast_db, diff_attrs='all')
+        albedo = dr.texture(
+            self.raw_albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode='linear-mipmap-linear')  # [1, H, W, 3]
+        albedo = torch.sigmoid(albedo)
+        # Zero out the background. The fill value has the dtype and device of `albedo`:
+        # an integer `torch.tensor(0)` is rejected by torch.where on PyTorch builds
+        # without dtype promotion, and differs from the vertex-texture path.
+        albedo = torch.where(rast[..., 3:] > 0, albedo, torch.zeros_like(albedo))
+
+        if shading == 'albedo':
+            return albedo
+
+        # Vertex positions may change, so normals are taken from the mesh of this call
+        # (plus the learned offsets when present).
+        vn = mesh.vn
+        if hasattr(self, 'vn_offsets'):
+            vn = vn + self.vn_offsets
+        normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, mesh.f)
+        normal = safe_normalize(normal)
+
+        lambertian = ambient_ratio + (1 - ambient_ratio) * (normal @ light_d).float().clamp(min=0)
+
+        if shading == 'textureless':
+            color = lambertian.unsqueeze(-1).repeat(1, 1, 1, 3)
+        elif shading == 'normal':
+            color = (normal + 1) / 2
+        else:  # 'lambertian'
+            color = albedo * lambertian.unsqueeze(-1)
+        return color
+
+    def get_color_from_3d_texture(self, rast, rast_db, v, f, vn, light_d, ambient_ratio, shading) -> Tensor:
+        """Shade the rasterized pixels by querying the 3D texture field.
+
+        Args:
+            rast, rast_db: outputs of ``dr.rasterize`` for the current view.
+            v, f: vertices and faces of the mesh the texture field is defined on.
+            vn: vertex normals, or ``None``; only needed when ``shading`` is
+                not 'albedo'.
+            light_d: light direction used for the Lambertian term.
+            ambient_ratio: weight of the ambient term.
+            shading: 'albedo', 'textureless', 'normal', or anything else for
+                Lambertian-shaded albedo.
+
+        Returns:
+            Color image reshaped to ``(*rast.shape[:-1], 3)``.
+
+        Raises:
+            ValueError: if ``vn`` is ``None`` but the shading mode needs normals.
+        """
+        xyzs, _ = dr.interpolate(v, rast, f, rast_db)
+        albedo = self.texture3d(xyzs.view(-1, 3))
+
+        if shading == 'albedo':
+            # Normals are not needed for raw albedo, so they are not interpolated.
+            color = albedo
+        else:
+            if vn is None:
+                # Previously this path failed later with an unbound `normal`.
+                raise ValueError(f"shading '{shading}' requires vertex normals, but vn is None")
+            normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f)
+            normal = safe_normalize(normal)
+            lambertian = ambient_ratio + (1 - ambient_ratio) * (normal @ light_d).float().clamp(min=0)
+            if shading == 'textureless':
+                color = lambertian.unsqueeze(-1).repeat(1, 1, 1, 3)
+            elif shading == 'normal':
+                color = (normal + 1) / 2
+            else:  # 'lambertian'
+                # `albedo` is flattened to (-1, 3); flatten the light term the same way.
+                color = albedo * lambertian.reshape(albedo.shape[:-1]).unsqueeze(-1)
+        return color.view(*rast.shape[:-1], 3)
+
+    def get_can_pos_map(self, rast, rast_db, f) -> Tensor:
+        """Interpolate the first canonical-pose vertex set over the rasterized pixels."""
+        canonical_vertices = self.can_pose_vertices[0]
+        position_map, _ = dr.interpolate(canonical_vertices, rast, f, rast_db)
+        return position_map
+
+    def get_openpose_map(self, keypoints, depth, rgb=None):
+        """Draw the projected keypoints as an OpenPose-style map.
+
+        Args:
+            keypoints: homogeneous projected keypoints; ``keypoints[0, 0]`` is
+                used and divided by its 4th component.
+            depth: depth image, indexed as ``depth[y, x, 0]``.
+            rgb: optional image drawn underneath, reshaped to (H, W, 3); a
+                black canvas is used when omitted.
+
+        Returns:
+            The result of ``draw_openpose_map`` for the canvas, the normalized
+            2D keypoints and the per-keypoint visibility mask.
+        """
+        from .pose_utils import draw_openpose_map
+
+        keypoints = keypoints[0, 0]
+        keypoints = keypoints[:, :3] / keypoints[:, 3:]
+        H, W = depth.shape[:2]
+        keypoints_2d = (keypoints[:, :2] + 1.) / 2  # [-1, 1] -> [0, 1]
+        keypoints_depth = keypoints[:, 2]
+
+        # Pixel coordinates of every keypoint, clamped to the image bounds.
+        pixels = (keypoints_2d.clamp(0, 1.) * keypoints_2d.new_tensor([W, H])).to(torch.int)
+        cols = pixels[:, 0].clamp(0, W - 1).long().to(depth.device)
+        rows = pixels[:, 1].clamp(0, H - 1).long().to(depth.device)
+        # One vectorized gather replaces the per-keypoint Python loop.
+        keypoints_depth_proj = depth[rows, cols, 0].to(keypoints_depth)
+
+        # Occlusion tolerance: a fifth of the depth extent of keypoints 56..123.
+        # NOTE(review): that index range presumably covers the face landmarks — confirm.
+        ref_depth = keypoints_depth[56:56 + 68]
+        depth_diff_thres = (ref_depth.max(dim=0)[0] - ref_depth.min(dim=0)[0]) / 5
+
+        # A keypoint is kept when it projects inside the image and is not further
+        # than the rendered surface by more than the tolerance.
+        inside = (keypoints_2d[:, 0] < 1) & (keypoints_2d[:, 0] >= 0) & (keypoints_2d[:, 1] < 1) & (keypoints_2d[:, 1] >= 0)
+        keypoints_mask = inside & (keypoints_depth < keypoints_depth_proj + depth_diff_thres)
+
+        if rgb is not None:
+            canvas = (rgb.detach().cpu().numpy().reshape(H, W, 3) * 255).astype(np.uint8)
+        else:
+            canvas = np.zeros((H, W, 3), dtype=np.uint8)
+        return draw_openpose_map(canvas, keypoints_2d.detach().cpu().numpy(), keypoints_mask.cpu().numpy())
+
+    def get_num_subdiv(self, global_step=1e7):
+        """Return the number of tet-grid subdivisions to use at ``global_step``.
+
+        When ``cfg.train.tet_subdiv_steps`` is set, the result is the number of
+        its entries that ``global_step`` has reached; otherwise the fixed
+        ``cfg.model.tet_num_subdiv`` is returned.
+        """
+        schedule = self.cfg.train.tet_subdiv_steps
+        if schedule is None:
+            return self.cfg.model.tet_num_subdiv
+        return sum(1 for threshold in schedule if global_step >= threshold)
+
+
+ def forward(self, rays_o, rays_d, mvp, h0, w0, light_d=None, ref_rgb=None, ambient_ratio=1.0, shading='albedo', return_loss=False, alpha_only=False, detach_geo=False, albedo_ref=False, poses=None, return_openpose_map=False, global_step=1e7, return_can_pos_map=False, mesh=None, can_pose=False):
+ # Rasterize the current mesh with nvdiffrast (`dr`) for one view and shade it.
+ #   rays_o, rays_d: ray origins / directions, flattened to (-1, 3) below.
+ #   mvp:            model-view-projection matrix, squeezed to 4x4.
+ #   h0, w0:         output resolution; rendering happens at (h, w), which is larger
+ #                   when cfg.model.render_ssaa > 1.
+ #   shading:        'albedo', 'normal', 'textureless' or anything else (lambertian).
+ #   mesh:           optional pre-built mesh; when None, get_mesh() is called.
+ #   can_pose:       rasterize a randomly chosen canonical-pose vertex/face set instead.
+ # Returns a dict with 'mesh', 'depth', 'image', 'alpha' and 'bg_color', plus
+ # 'albedo_ref', 'geo_reg_loss', 'openpose_map' and 'can_pos_map' when requested or
+ # available. With alpha_only=True only {'alpha': ...} is returned.
+ # `ref_rgb` is accepted but not read in this method.
+ # mvp: [1, 4, 4]
+ mvp = mvp.squeeze()
+ # NOTE: `device` is not used further down in this method.
+ device = mvp.device
+
+ rays_o = rays_o.contiguous().view(-1, 3)
+ rays_d = rays_d.contiguous().view(-1, 3)
+
+ # do super-sampling
+ if self.cfg.model.render_ssaa > 1:
+ h = int(h0 * self.cfg.model.render_ssaa)
+ w = int(w0 * self.cfg.model.render_ssaa)
+ # Without the GL rasterizer the supersampled size is capped at 2048 per side.
+ if not self.cfg.use_gl:
+ h = min(h, 2048)
+ w = min(w, 2048)
+ dirs = rays_d.view(h0, w0, 3)
+ dirs = scale_img_hwc(dirs, (h, w), mag='nearest').view(-1, 3).contiguous()
+ else:
+ h, w = h0, w0
+ dirs = rays_d
+
+ # A constant direction makes the background network output a single color.
+ if self.cfg.model.single_bg_color:
+ dirs = torch.ones_like(dirs)
+
+ # Normalize the directions and flip their x and z components before encoding.
+ dirs = dirs / torch.norm(dirs, dim=-1, keepdim=True)
+ dirs[..., 0] = -dirs[..., 0]
+ dirs[..., 2] = -dirs[..., 2]
+
+ # mix background color
+ # With cfg.model.different_bg, 'textureless' and 'normal' use their own background nets.
+ if self.cfg.model.different_bg and shading == 'textureless':
+ bg_color = torch.sigmoid(self.textureless_bg_net(self.encoder_bg(dirs))).view(h, w, 3)
+ elif self.cfg.model.different_bg and shading == 'normal':
+ bg_color = torch.sigmoid(self.normal_bg_net(self.encoder_bg(dirs))).view(h, w, 3)
+ else:
+ bg_color = torch.sigmoid(self.bg_net(self.encoder_bg(dirs))).view(h, w, 3)
+
+ results = {}
+ geo_reg_loss = None
+ # Reuse the caller's mesh when given (no geometry loss is produced in that case).
+ if mesh is None:
+ mesh, geo_reg_loss = self.get_mesh(return_loss=return_loss, detach_geo=detach_geo, global_step=global_step)
+
+ results['mesh'] = mesh
+ v = mesh.v # [N, 3]
+ f = mesh.f
+ if can_pose:
+ v, f = random.choice(list(zip(self.can_pose_vertices, self.can_pose_faces)))
+
+ # Project the vertices to clip space (homogeneous 1 appended, multiplied by mvp^T).
+ v_clip = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0),
+ torch.transpose(mvp, 0, 1)).float().unsqueeze(0) # [1, N, 4]
+ rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))
+
+ # Coverage mask: interpolating a constant 1 gives 1 on the mesh.
+ mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, H, W, 1]
+ alpha = mask.clone()
+ # NOTE(review): rasterization uses `f`, but the dr.antialias calls below pass `mesh.f`.
+ # The two differ when can_pose selects a canonical face set — confirm both share topology.
+ if alpha_only:
+ alpha = dr.antialias(alpha, rast, v_clip, mesh.f).squeeze(0).clamp(0, 1) # [H, W, 3]
+ if self.cfg.model.render_ssaa > 1:
+ alpha = scale_img_hwc(alpha, (h0, w0))
+ return dict(alpha=alpha)
+ # xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, self.mesh.f) # [1, H, W, 3]
+ # xyzs = xyzs.view(-1, 3)
+ # mask = (mask > 0).view(-1)
+ # albedo = torch.zeros_like(xyzs, dtype=torch.float32)
+ # if mask.any():
+ # masked_albedo = torch.sigmoid(self.color_net(self.encoder(xyzs[mask].detach(), bound=1)))
+ # albedo[mask] = masked_albedo.float()
+ # albedo = albedo.view(1, h, w, 3)cuda
+ # Mesh on which texture colors are looked up: the rendered mesh, or the first
+ # canonical-pose mesh when cfg.model.use_can_pose_space is set.
+ if not self.cfg.model.use_can_pose_space:
+ color_space_v, color_space_f = mesh.v, mesh.f
+ else:
+ color_space_v, color_space_f = self.can_pose_vertices[0], self.can_pose_faces[0]
+
+
+ if shading != 'albedo' and light_d is None: # random sample light_d if not provided
+ # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
+ light_d = rays_o[0] + torch.randn(3, device=rays_o.device, dtype=torch.float)
+ #light_d = random.choice(-rays_d.view(-1, 3))#(rays_o[0] + torch.randn(3, device=rays_o.device, dtype=torch.float))
+ light_d = safe_normalize(light_d)
+ # Dispatch to the shading routine: geometry-only modes first, then the configured
+ # texture representation (2D texture, vertex colors, or the 3D texture field).
+ if shading in ['normal', 'textureless']:
+ color = self.get_color_from_mesh(mesh, rast, light_d, ambient_ratio, shading, poses=poses)
+ elif self.cfg.model.use_texture_2d:
+ color = self.get_color_from_2d_texture(rast, rast_db, mesh, rays_o, light_d, ambient_ratio, shading)
+ elif self.cfg.model.use_vertex_tex:
+ color = self.get_color_from_vertex_texture(rast, rast_db, color_space_f, light_d, ambient_ratio, shading)
+ else:
+ color = self.get_color_from_3d_texture(rast, rast_db, color_space_v, color_space_f, mesh.vn, light_d, ambient_ratio, shading)
+
+ color = dr.antialias(color, rast, v_clip, mesh.f).squeeze(0).clamp(0, 1) # [H, W, 3]
+ alpha = dr.antialias(alpha, rast, v_clip, mesh.f).squeeze(0).clamp(0, 1) # [H, W, 3]
+ # color = color.squeeze(0).clamp(0, 1)
+ # alpha = alpha.squeeze(0).clamp(0, 1)
+ # Third rasterizer channel; indexing with the list [2] keeps a trailing dimension of size 1.
+ depth = rast[0, :, :, [2]] # [H, W]
+
+ # Alpha-composite the foreground over the predicted background.
+ color = color * alpha + (1 - alpha) * bg_color
+
+ # ssaa
+
+ # Optional unshaded albedo render, only for the 3D-texture representation, without gradients.
+ if albedo_ref and not (self.cfg.model.use_vertex_tex) and (not self.cfg.model.use_texture_2d):
+ with torch.no_grad():
+ albedo = self.get_color_from_3d_texture(rast.detach(), rast_db.detach(), color_space_v.detach(), color_space_f, None, light_d, 1.0, 'albedo')
+ albedo = dr.antialias(albedo, rast, v_clip, mesh.f).squeeze(0).clamp(0, 1) # [H, W, 3]
+ albedo = albedo * alpha + (1 - alpha) * bg_color
+ if self.cfg.model.render_ssaa > 1:
+ albedo = scale_img_hwc(albedo, (h0, w0))
+ results['albedo_ref'] = albedo
+
+ # Bring the supersampled buffers back to the requested output resolution.
+ if self.cfg.model.render_ssaa > 1:
+ color = scale_img_hwc(color, (h0, w0))
+ alpha = scale_img_hwc(alpha, (h0, w0))
+ depth = scale_img_hwc(depth, (h0, w0))
+ bg_color = scale_img_hwc(bg_color, (h0, w0))
+
+ results['depth'] = depth
+ results['image'] = color
+ results['alpha'] = alpha
+ results['bg_color'] = bg_color
+ if geo_reg_loss is not None:
+ results['geo_reg_loss'] = geo_reg_loss
+
+ if return_openpose_map:
+ # Project the stored keypoints with the same mvp; the rendered image is used as the
+ # canvas only when cfg.test.test is set.
+ keypoints_2d = torch.matmul(F.pad(self.keypoints, pad=(0, 1), mode='constant', value=1.0),
+ torch.transpose(mvp, 0, 1)).float().unsqueeze(0) # [1, N, 4]
+ results['openpose_map'] = depth.new_tensor(self.get_openpose_map(keypoints_2d, depth, color if self.cfg.test.test else None)) / 255
+ # results['image'] = torch.flip(results['image'], dims=[-2])
+ # results['openpose_map'] = torch.flip(results['openpose_map'], dims=[-2])
+ if return_can_pos_map:
+ results['can_pos_map'] = self.get_can_pos_map(rast.detach(), rast_db.detach(), mesh.f)
+ if self.cfg.model.render_ssaa > 1:
+ results['can_pos_map'] = scale_img_hwc(results['can_pos_map'], (h0, w0))
+
+
+ return results
\ No newline at end of file
diff --git a/core/lib/tet_utils.py b/core/lib/tet_utils.py
new file mode 100755
index 0000000..16d4d15
--- /dev/null
+++ b/core/lib/tet_utils.py
@@ -0,0 +1,45 @@
+import pyvista as pv
+import pymeshlab
+import tetgen
+import os.path as osp
+import os
+import numpy as np
+
+def build_tet_grid(mesh, cfg):
+    """Build, or load from the workspace cache, a tetrahedral grid around ``mesh``.
+
+    The surface is dilated by ``cfg.model.tet_shell_offset`` with pymeshlab,
+    decimated with pyvista and tetrahedralized with tetgen. The result is
+    cached at ``<cfg.workspace>/tet/tet_grid.npz`` and reused on later calls.
+    If a build attempt fails, the shell offset is halved and the build is
+    retried while the offset stays above 1/16 of the configured value.
+
+    Args:
+        mesh: object exposing torch tensors ``v`` (vertices) and ``f`` (faces).
+        cfg: configuration providing ``workspace``, ``data.last_model`` and the
+            ``model.tet_*`` settings read below.
+
+    Returns:
+        ``(vertices, indices)`` arrays of the tetrahedral grid.
+
+    Raises:
+        RuntimeError: if no attempt produced a grid.
+    """
+    assert cfg.data.last_model.split('.')[-1] == 'obj'
+    tet_dir = osp.join(cfg.workspace, 'tet')
+    os.makedirs(tet_dir, exist_ok=True)
+    save_path = osp.join(tet_dir, 'tet_grid.npz')
+    if osp.exists(save_path):
+        print('Loading exist tet grids from {}'.format(save_path))
+        # Close the .npz handle once both arrays have been read.
+        with np.load(save_path) as tets:
+            vertices = tets['vertices']
+            indices = tets['indices']
+        print('shape of vertices: {}, shape of grids: {}'.format(vertices.shape, indices.shape))
+        return vertices, indices
+
+    print('Building tet grids...')
+    # Convert the input surface once. The loop below must not rebind `mesh`,
+    # otherwise a retry would read `.v` / `.f` from the pyvista object.
+    surface_v = mesh.v.cpu().numpy()
+    surface_f = mesh.f.cpu().numpy()
+    dilated_path = osp.join(tet_dir, 'dilated_mesh.obj')
+    tet_shell_offset = cfg.model.tet_shell_offset
+    min_offset = cfg.model.tet_shell_offset / 16
+    last_error = None
+    while tet_shell_offset > min_offset:
+        try:
+            ms = pymeshlab.MeshSet()
+            ms.add_mesh(pymeshlab.Mesh(surface_v, surface_f))
+            ms.generate_resampled_uniform_mesh(offset=pymeshlab.AbsoluteValue(tet_shell_offset))
+            ms.save_current_mesh(dilated_path)
+            dilated_mesh = pv.read(dilated_path)
+            downsampled_mesh = dilated_mesh.decimate(cfg.model.tet_shell_decimate)
+            tet = tetgen.TetGen(downsampled_mesh)
+            tet.make_manifold(verbose=True)
+            vertices, indices = tet.tetrahedralize(fixedvolume=1,
+                                                   maxvolume=cfg.model.tet_grid_volume,
+                                                   regionattrib=1,
+                                                   nobisect=False, steinerleft=-1, order=1, metric=1, meditview=1, nonodewritten=0, verbose=2)
+        except Exception as err:
+            # Retry with a thinner shell instead of aborting on the first failure.
+            last_error = err
+            print('Tet grid build failed with shell offset {}: {}'.format(tet_shell_offset, err))
+            tet_shell_offset /= 2
+            continue
+        shell = tet.grid.extract_surface()
+        shell.save(osp.join(tet_dir, 'shell_surface.ply'))
+        np.savez(save_path, vertices=vertices, indices=indices)
+        print('shape of vertices: {}, shape of grids: {}'.format(vertices.shape, indices.shape))
+        return vertices, indices
+    raise RuntimeError("Failed to initialize tetrahedra grid!") from last_error
\ No newline at end of file
diff --git a/core/lib/trainer.py b/core/lib/trainer.py
new file mode 100755
index 0000000..68f081a
--- /dev/null
+++ b/core/lib/trainer.py
@@ -0,0 +1,1143 @@
+import os
+import glob
+import tqdm
+import imageio
+import random
+import tensorboardX
+
+import numpy as np
+
+import time
+
+import cv2
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.distributed as dist
+from PIL import Image
+from rich.console import Console
+from torch_ema import ExponentialMovingAverage
+
+from .chamfer import chamfer_distance
+
+from .annotators import HEDdetector, Cannydetector
+from .color_utils import convert_rgb
+from .loss_utils import *
+from .camera_utils import *
+import math
+
+from thirdparties.lpips import LPIPS
+
+def scale_for_lpips(image_tensor):
+    """Affinely rescale values as x -> 2x - 1 (maps [0, 1] onto [-1, 1])."""
+    doubled = image_tensor * 2.
+    return doubled - 1.
+
+
+def seed_everything(seed):
+    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) and export PYTHONHASHSEED."""
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    # The remaining generators are seeded in the same order as before.
+    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
+        seed_fn(seed)
+
+
+class Trainer(object):
+
+    def __init__(
+        self,
+        name,  # name of this experiment
+        cfg,  # extra conf
+        model,  # network
+        guidance,  # guidance network
+        criterion=None,  # loss function, if None, assume inline implementation in train_step
+        optimizer=None,  # optimizer
+        ema_decay=None,  # if use EMA, set the decay
+        lr_scheduler=None,  # scheduler
+        metrics=[],  # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.
+        local_rank=0,  # which GPU am I
+        world_size=1,  # total num of GPUs
+        device=None,  # device to use, usually setting to None is OK. (auto choose device)
+        mute=False,  # whether to mute all print
+        fp16=False,  # amp optimize level
+        eval_interval=1,  # eval once every $ epoch
+        max_keep_ckpt=2,  # max num of saved ckpts in disk
+        workspace='workspace',  # workspace to save logs & ckpts
+        best_mode='min',  # the smaller/larger result, the better
+        use_loss_as_metric=True,  # use loss as the first metric
+        report_metric_at_train=False,  # also report metrics at training
+        use_checkpoint="latest",  # which ckpt to use at init time
+        pretrained=None,
+        use_tensorboardX=True,  # whether to use tensorboard for logging
+        scheduler_update_every_step=False,  # whether to call scheduler.step() after every train step
+    ):
+        """Set up the model, guidance, optimizer, logging and checkpoints.
+
+        Besides storing the arguments, the constructor moves the model to the
+        target device (wrapping it in DDP when ``world_size > 1``), loads the
+        input image when ``cfg.data.img`` is set, prepares the text embeddings
+        when a guidance model is given, builds the optimizer, scheduler, EMA
+        and grad scaler, opens the log file in the workspace, and loads the
+        pretrained weights and checkpoint selected by ``pretrained`` and
+        ``use_checkpoint``.
+        """
+        # Must exist before anything can call self.log() or __del__(): both read it,
+        # and prepare_text_embeddings() below may log a warning.
+        self.log_ptr = None
+
+        self.name = name
+        self.cfg = cfg
+        self.stage = self.cfg.stage
+        self.mute = mute
+        self.metrics = metrics
+        self.local_rank = local_rank
+        self.world_size = world_size
+        self.workspace = workspace
+        self.ema_decay = ema_decay
+        self.fp16 = fp16
+        self.best_mode = best_mode
+        self.use_loss_as_metric = use_loss_as_metric
+        self.report_metric_at_train = report_metric_at_train
+        self.max_keep_ckpt = max_keep_ckpt
+        self.eval_interval = eval_interval
+        self.use_checkpoint = use_checkpoint
+        self.use_tensorboardX = use_tensorboardX
+        self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
+        self.scheduler_update_every_step = scheduler_update_every_step
+        self.device = device if device is not None else torch.device(
+            f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
+        self.console = Console()
+
+        model.to(self.device)
+        model.mesh.to(self.device)
+        if self.world_size > 1:
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
+        self.model = model
+        if self.cfg.data.img:
+            self._load_input_image()
+
+        # guide model
+        self.guidance = guidance
+
+        # text prompt
+        if self.guidance is not None:
+            self.prepare_text_embeddings()
+        else:
+            self.text_z = None
+
+        # try out torch 2.0
+        # get_device_capability() is only valid for CUDA devices, so check the device type first.
+        target = torch.device(self.device)
+        if torch.__version__[0] == '2' and target.type == 'cuda' and torch.cuda.get_device_capability(target)[0] >= 7:
+            self.model = torch.compile(self.model)
+            # torch.compile(None) returns a decorator, which would turn a missing
+            # guidance model into a non-None callable.
+            if self.guidance is not None:
+                self.guidance = torch.compile(self.guidance)
+
+        if isinstance(criterion, nn.Module):
+            criterion.to(self.device)
+        self.criterion = criterion
+
+        self.optimizer_fn = optimizer
+        self.lr_scheduler_fn = lr_scheduler
+
+        if optimizer is None:
+            self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4)  # naive adam
+        else:
+            self.optimizer = optimizer(self.model)
+
+        if lr_scheduler is None:
+            self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1)  # fake scheduler
+        else:
+            self.lr_scheduler = lr_scheduler(self.optimizer)
+
+        if ema_decay is not None:
+            self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay)
+        else:
+            self.ema = None
+
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
+
+        # variable init
+        self.epoch = 0
+        self.global_step = 0
+        self.local_step = 0
+        self.stats = {
+            "loss": [],
+            "valid_loss": [],
+            "results": [],  # metrics[0], or valid_loss
+            "checkpoints": [],  # record path of saved ckpt, to automatically remove old ckpt
+            "best_result": None,
+        }
+
+        # auto fix
+        if len(metrics) == 0 or self.use_loss_as_metric:
+            self.best_mode = 'min'
+
+        # workspace prepare
+        if self.workspace is not None:
+            os.makedirs(self.workspace, exist_ok=True)
+            self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
+            self.log_ptr = open(self.log_path, "a+")
+
+            self.ckpt_path = os.path.join(self.workspace, 'checkpoints')
+            self.best_path = f"{self.ckpt_path}/{self.name}.pth"
+            os.makedirs(self.ckpt_path, exist_ok=True)
+
+        self.log(
+            f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}'
+        )
+        self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')
+
+        self.load_pretrained(pretrained=pretrained)
+
+        if self.workspace is not None:
+            if self.use_checkpoint == "scratch":
+                self.log("[INFO] Training from scratch ...")
+            elif self.use_checkpoint == "latest":
+                self.log("[INFO] Loading latest checkpoint ...")
+                self.load_checkpoint()
+            elif self.use_checkpoint == "latest_model":
+                self.log("[INFO] Loading latest checkpoint (model only)...")
+                self.load_checkpoint(model_only=True)
+            elif self.use_checkpoint == "best":
+                if os.path.exists(self.best_path):
+                    self.log("[INFO] Loading best checkpoint ...")
+                    self.load_checkpoint(self.best_path)
+                else:
+                    self.log(f"[INFO] {self.best_path} not found, loading latest ...")
+                    self.load_checkpoint()
+            else:  # path to ckpt
+                self.log(f"[INFO] Loading {self.use_checkpoint} ...")
+                self.load_checkpoint(self.use_checkpoint)
+
+        if self.cfg.train.lambda_recon > 0 or self.cfg.train.lambda_normal > 0.:
+            # Frozen perceptual-loss network, placed on the trainer's device rather than
+            # whichever CUDA device happens to be current.
+            self.lpips_model = LPIPS(net='vgg').to(self.device)
+            for param in self.lpips_model.parameters():
+                param.requires_grad = False
+
+        if self.cfg.guidance.controlnet_guidance_geometry:
+            if self.cfg.guidance.controlnet_guidance_geometry == 'hed':
+                self.controlnet_annotator = HEDdetector()
+            elif self.cfg.guidance.controlnet_guidance_geometry == 'canny':
+                self.controlnet_annotator = Cannydetector(100, 200)
+            else:
+                raise NotImplementedError
+        self.render_openpose_training = self.cfg.guidance.controlnet_openpose_guidance
+        self.render_openpose = self.cfg.guidance.controlnet_openpose_guidance
+
+    def _split_rgba(self, rgba):
+        """Split an (H, W, 4) float array into masked RGB, mask and ``get_edt`` of the mask.
+
+        Returns ``(image, mask, mask_edt)`` where ``image`` is (3, H, W) and
+        already multiplied by ``mask``, which is (1, H, W); both live on
+        ``self.device``.
+        """
+        mask = torch.tensor(rgba[..., 3], dtype=torch.float).to(self.device).unsqueeze(0)
+        mask_edt = get_edt(mask.unsqueeze(0))[0]
+        image = torch.tensor(rgba[..., :3], dtype=torch.float).to(self.device).permute(2, 0, 1)
+        return image * mask, mask, mask_edt
+
+    def _load_input_image(self):
+        """Load the reference image, optional normal maps and the loss masks.
+
+        Populates ``input_image`` / ``input_mask`` (also mirrored onto the
+        model), the front and back normal images with their masks, the loss
+        masks for image and normal supervision, and the eroded masks. All
+        images are expected to be RGBA; the alpha channel is the mask.
+        ``loss_mask_norm`` is ``None`` when no front normal image is
+        configured.
+        """
+        rgba = Image.open(self.cfg.data.img)
+        if rgba.width > 2048 or rgba.height > 2048:
+            rgba = rgba.resize((2048, 2048))
+        self.input_image, self.input_mask, self.input_mask_edt = self._split_rgba(np.array(rgba) / 255)
+        self.model.input_image = self.input_image
+        self.model.input_mask = self.input_mask
+
+        if self.cfg.data.front_normal_img is not None:
+            front = np.array(Image.open(self.cfg.data.front_normal_img)) / 255
+            self.normal_image, self.normal_mask, self.normal_mask_edt = self._split_rgba(front)
+        else:
+            self.normal_mask = None
+            self.normal_image = None
+
+        if self.cfg.data.back_normal_img is not None:
+            back = np.array(Image.open(self.cfg.data.back_normal_img)) / 255
+            self.back_normal_image, self.back_normal_mask, self.back_normal_mask_edt = self._split_rgba(back)
+        else:
+            self.back_normal_mask = None
+            self.back_normal_image = None
+
+        has_normal = self.normal_mask is not None
+        if self.cfg.data.loss_mask is not None:
+            mask_img = Image.open(self.cfg.data.loss_mask)
+            # PIL's resize takes (width, height), while tensor shapes are (..., height, width).
+            img_h, img_w = self.input_image.shape[-2:]
+            loss_mask = np.array(mask_img.resize((int(img_w), int(img_h))))[..., -1] / 255
+            self.loss_mask = torch.tensor(loss_mask, dtype=torch.float).to(self.device).unsqueeze(0) * self.input_mask
+            if has_normal:
+                # Follow the normal map's own resolution instead of assuming 512x512.
+                norm_h, norm_w = self.normal_mask.shape[-2:]
+                loss_mask_norm = np.array(mask_img.resize((int(norm_w), int(norm_h))))[..., -1] / 255
+                self.loss_mask_norm = torch.tensor(loss_mask_norm, dtype=torch.float).to(self.device).unsqueeze(0) * self.normal_mask
+            else:
+                self.loss_mask_norm = None
+        elif self.cfg.data.occ_mask is not None:
+            self.loss_mask = torch.tensor(self.get_loss_mask(self.cfg.data.occ_mask, self.cfg.data.seg_mask, self.input_image.shape[-2:]), dtype=torch.float).to(self.device) * self.input_mask
+            if has_normal:
+                self.loss_mask_norm = torch.tensor(self.get_loss_mask(self.cfg.data.occ_mask, self.cfg.data.seg_mask, self.normal_mask.shape[-2:]), dtype=torch.float).to(self.device) * self.normal_mask
+            else:
+                self.loss_mask_norm = None
+        else:
+            self.loss_mask = torch.ones_like(self.input_mask)
+            # Without a front normal image there is nothing to mask.
+            self.loss_mask_norm = torch.ones_like(self.normal_mask) if has_normal else None
+
+        if self.cfg.train.loss_mask_erosion is not None:
+            kernel = np.ones((self.cfg.train.loss_mask_erosion, self.cfg.train.loss_mask_erosion), np.float32)
+            norm_kernel = np.ones((self.cfg.train.loss_mask_erosion // 2, self.cfg.train.loss_mask_erosion // 2), np.float32)
+
+            def erode(mask, structuring_element):
+                # The third positional argument of cv2.erode is `dst`, so the border mode
+                # has to be passed by keyword to take effect.
+                eroded = cv2.erode(mask.cpu().numpy()[0], structuring_element, borderType=cv2.BORDER_REFLECT)
+                return torch.tensor(eroded).to(self.device).unsqueeze(0)
+
+            self.erosion_mask = erode(self.input_mask, kernel)
+            self.erosion_normal_mask = erode(self.normal_mask, norm_kernel) if has_normal else None
+            if self.back_normal_mask is not None:
+                self.erosion_back_normal_mask = erode(self.back_normal_mask, norm_kernel)
+            else:
+                self.erosion_back_normal_mask = None
+        else:
+            self.erosion_mask = None
+            self.erosion_normal_mask = None
+            self.erosion_back_normal_mask = None
+
+        self.input_can_pos_map = None
+
+    def get_loss_mask(self, occ_map_path, seg_path, img_size):
+        """Build a loss mask from an occlusion map and an optional segmentation map.
+
+        Args:
+            occ_map_path: path of the occlusion image.
+            seg_path: path of the segmentation image, or ``None``.
+            img_size: ``(height, width)`` of the result; callers pass the last
+                two dimensions of an image tensor.
+
+        Returns:
+            Without ``seg_path``: the occlusion map scaled to [0, 1].
+            With ``seg_path``: a boolean array that is true where the
+            occlusion map is positive and the segmentation map is zero.
+        """
+        # PIL's resize takes (width, height), the reverse of a tensor's trailing shape.
+        height, width = int(img_size[0]), int(img_size[1])
+        pil_size = (width, height)
+
+        occ_map = np.array(Image.open(occ_map_path).resize(pil_size)) / 255
+        if occ_map.ndim == 3:
+            occ_map = occ_map[..., -1]
+        if seg_path is not None:
+            seg_map = np.array(Image.open(seg_path).resize(pil_size)) / 255
+            if seg_map.ndim == 3:
+                seg_map = seg_map[..., -1]
+            # Python's `and` cannot combine arrays (it raises ValueError); combine elementwise.
+            occ_map = np.logical_and(occ_map > 0, seg_map == 0)
+        return occ_map
+
+ # calculate the text embs.
+ def prepare_text_embeddings(self):
+ # Pre-compute the text embeddings used by the guidance model.
+ # Without cfg.guidance.use_view_prompt a single embedding is stored in self.text_z.
+ # Otherwise, for each prompt family (RGB, face, normal, textureless) a view-independent
+ # embedding (`*_novd`) and a list of six view-dependent embeddings are stored; the lists
+ # follow the order front, side, back, side, overhead, bottom and are indexed by
+ # data['dir'] in train_step.
+
+ if self.cfg.guidance.text is None:
+ self.log(f"[WARN] text prompt is not provided.")
+ self.text_z = None
+ return
+
+ if not self.cfg.guidance.use_view_prompt:
+ self.text_z = self.guidance.get_text_embeds([self.cfg.guidance.text], [self.cfg.guidance.negative])
+ else:
+ print('get rgb text prompt')
+ self.text_z_novd = self.guidance.get_text_embeds([self.cfg.guidance.text], [self.cfg.guidance.negative])
+ self.text_z = []
+ for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']:
+ # construct dir-encoded text
+ text = f"{self.cfg.guidance.text}, {d} view, {self.cfg.guidance.text_extra}"
+
+ negative_text = f"{self.cfg.guidance.negative}"
+
+ # explicit negative dir-encoded text
+ text_z = self.guidance.get_text_embeds([text], [negative_text])
+ self.text_z.append(text_z)
+ # Face prompts are only built when face views are sampled during training.
+ if self.cfg.train.face_sample_ratio > 0.:
+ self.face_text_z_novd = self.guidance.get_text_embeds([f"the face of {self.cfg.guidance.text}, {self.cfg.guidance.text_extra}"], [self.cfg.guidance.negative])
+ self.face_text_z = []
+ # Prefer the dedicated head prompt when it is set and non-empty.
+ prompt = self.cfg.guidance.text_head if (self.cfg.guidance.text_head is not None) and (len(self.cfg.guidance.text_head) > 0) else self.cfg.guidance.text
+ for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']:
+ # construct dir-encoded text
+ text = f"the face of {prompt}, {d} view, {self.cfg.guidance.text_extra}"
+
+ # NOTE(review): the view-dependent face prompts use `negative_normal`, while the
+ # view-independent face prompt above uses `negative` — confirm this is intended.
+ negative_text = f"{self.cfg.guidance.negative_normal}"
+
+ # explicit negative dir-encoded text
+ # This is the only call in this method that passes is_face=True.
+ text_z = self.guidance.get_text_embeds([text], [negative_text], is_face=True)
+ self.face_text_z.append(text_z)
+ if (self.cfg.guidance.normal_text is not None) and (len(self.cfg.guidance.normal_text) > 0):
+ print('get normal text prompt')
+ # Use the geometry-specific prompt when configured, else the main text prompt.
+ basic_prompt = self.cfg.guidance.text if (self.cfg.guidance.text_geo is None) or (len(self.cfg.guidance.text_geo)==0) else self.cfg.guidance.text_geo
+ self.normal_text_z_novd = self.guidance.get_text_embeds([f"{self.cfg.guidance.normal_text} of {basic_prompt}, {self.cfg.guidance.normal_text_extra}"], [self.cfg.guidance.negative_normal])
+ self.normal_text_z = []
+ for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']:
+ # construct dir-encoded text
+ text = f"{self.cfg.guidance.normal_text} of {basic_prompt}, {d} view, {self.cfg.guidance.normal_text_extra}"
+
+ negative_text = f"{self.cfg.guidance.negative_normal}"
+
+ # explicit negative dir-encoded text
+ text_z = self.guidance.get_text_embeds([text], [negative_text])
+ self.normal_text_z.append(text_z)
+ self.face_normal_text_z_novd = self.guidance.get_text_embeds([f"{self.cfg.guidance.normal_text} of the face of {basic_prompt}, {self.cfg.guidance.normal_text_extra}"], [self.cfg.guidance.negative_normal])
+ self.face_normal_text_z = []
+ # From here on the head prompt, when set, replaces the basic prompt. The
+ # view-independent face embedding above was built before this switch.
+ basic_prompt = self.cfg.guidance.text_head if (self.cfg.guidance.text_head is not None) and (len(self.cfg.guidance.text_head) > 0) else basic_prompt
+ for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']:
+ # construct dir-encoded text
+ text = f"{self.cfg.guidance.normal_text} of the face of {basic_prompt}, {d} view, {self.cfg.guidance.normal_text_extra}"
+
+ negative_text = f"{self.cfg.guidance.negative_normal}"
+
+ # explicit negative dir-encoded text
+ text_z = self.guidance.get_text_embeds([text], [negative_text])
+ self.face_normal_text_z.append(text_z)
+ if (self.cfg.guidance.textureless_text is not None) and (len(self.cfg.guidance.textureless_text))>0:
+ print('get textureless text prompt')
+ self.textureless_text_z_novd = self.guidance.get_text_embeds([f"{self.cfg.guidance.textureless_text} of {self.cfg.guidance.text}, {self.cfg.guidance.textureless_text_extra}"], [self.cfg.guidance.negative_textureless])
+ self.textureless_text_z = []
+ for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']:
+ # construct dir-encoded text
+ text = f"{self.cfg.guidance.textureless_text} of {self.cfg.guidance.text}, {d} view, {self.cfg.guidance.textureless_text_extra}"
+
+ negative_text = f"{self.cfg.guidance.negative_textureless}"
+
+ # explicit negative dir-encoded text
+ text_z = self.guidance.get_text_embeds([text], [negative_text])
+ self.textureless_text_z.append(text_z)
+ self.face_textureless_text_z_novd = self.guidance.get_text_embeds([f"{self.cfg.guidance.textureless_text} of the face of {self.cfg.guidance.text}, {self.cfg.guidance.textureless_text_extra}"], [self.cfg.guidance.negative_textureless])
+ self.face_textureless_text_z = []
+ for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']:
+ # construct dir-encoded text
+ text = f"{self.cfg.guidance.textureless_text} of the face of {self.cfg.guidance.text}, {d} view, {self.cfg.guidance.textureless_text_extra}"
+
+ negative_text = f"{self.cfg.guidance.negative_textureless}"
+
+ # explicit negative dir-encoded text
+ text_z = self.guidance.get_text_embeds([text], [negative_text])
+ self.face_textureless_text_z.append(text_z)
+    def __del__(self):
+        """Close the log file if one was opened."""
+        # The constructor can fail before `log_ptr` is assigned; looking the attribute up
+        # defensively avoids a second error during garbage collection.
+        log_ptr = getattr(self, 'log_ptr', None)
+        if log_ptr:
+            log_ptr.close()
+
+    def log(self, *args, **kwargs):
+        """Print a message to the console and append it to the log file.
+
+        Only rank 0 logs. Console output is skipped when ``self.mute`` is set;
+        the log file still receives the message when it is open. Keyword
+        arguments are forwarded to the console print only.
+        """
+        if self.local_rank != 0:
+            return
+        if not self.mute:
+            self.console.print(*args, **kwargs)
+        # The log file may not have been opened yet (or at all), e.g. for messages
+        # emitted while the trainer is still being constructed.
+        log_ptr = getattr(self, 'log_ptr', None)
+        if log_ptr:
+            print(*args, file=log_ptr)
+            log_ptr.flush()  # write immediately to file
+
+ ### ------------------------------
+
+ def train_step(self, data):
+ """Compute the training loss for one batch and return (pred_rgb, pred_depth, loss).
+
+ Shading selection: with `cfg.train.train_both` both 'normal' and 'albedo'
+ are rendered. Otherwise one random draw is compared against
+ `normal_sample_ratio` and `textureless_sample_ratio` (both only honoured
+ when `cfg.train.lock_geo` is false), and a second draw against
+ `albedo_sample_ratio` chooses between 'albedo' and 'lambertian'.
+
+ For each selected shading the view described by `data` is rendered and
+ `self.guidance.train_step` contributes a guidance loss. The method also
+ adds a Laplacian smoothness term, silhouette / LPIPS / MSE terms against
+ the stored reference images rendered from a fixed camera, and a colour
+ chamfer term for 'albedo' shading when enabled in the config.
+
+ Keys read from `data`: 'mvp', 'poses', 'H', 'W', 'intrinsics',
+ 'can_pose', 'dir', 'is_face'.
+
+ NOTE(review): indentation is lost in this view, so which statements sit
+ under `if i_shading == 0:` and under the per-shading loop cannot be
+ confirmed from here.
+ """
+
+ # Pick the shading mode(s) and matching ambient ratio for this step.
+ rand1 = random.random()
+ flag_train_geometry =(not self.cfg.train.lock_geo)
+ if self.cfg.train.train_both:
+ shadings = ['normal', 'albedo']
+ ambient_ratio = 1.0
+ elif rand1 < self.cfg.train.normal_sample_ratio and flag_train_geometry:
+ shadings = ['normal']
+ ambient_ratio = 0.1
+ elif rand1 < self.cfg.train.textureless_sample_ratio and flag_train_geometry:
+ shadings = ['textureless']
+ ambient_ratio = 0.1
+ else:
+ rand = random.random()
+ if rand < self.cfg.train.albedo_sample_ratio:
+ shadings = ['albedo']
+ ambient_ratio = 1.0
+ else:
+ shadings = ['lambertian']
+ ambient_ratio = 0.1
+ loss = 0
+ # Mesh returned by the first render; passed back into later model calls via `mesh=`.
+ step_mesh = None
+ for i_shading, shading in enumerate(shadings):
+ mvp = data['mvp']
+ poses = data['poses']
+ H, W = data['H'], data['W']
+
+ rays = get_rays(data['poses'], data['intrinsics'], H, W, -1)
+ rays_o = rays['rays_o'] # [B, N, 3]
+ rays_d = rays['rays_d'] # [B, N, 3]
+ outputs = self.model(rays_o, rays_d, mvp, H, W,
+ poses=poses,
+ ambient_ratio=ambient_ratio,
+ shading=shading,
+ return_openpose_map=self.render_openpose_training,
+ global_step=self.global_step,
+ can_pose=data['can_pose'],
+ mesh=step_mesh)
+ pred_rgb = outputs['image'].reshape(1, H, W, 3).permute(0, 3, 1, 2).contiguous() # [1, 3, H, W]
+ pred_alpha = outputs['alpha'].reshape(1, H, W, 1).permute(0, 3, 1, 2).contiguous() # [1, 1, H, W]
+ pred_depth = outputs['depth'].reshape(1, H, W)
+ if step_mesh is None:
+ step_mesh = outputs['mesh']
+
+ # text embeddings
+ # Choose the embedding set by shading type and by whether this is a face view;
+ # `dirs` indexes the per-direction embedding lists.
+ if self.cfg.guidance.use_view_prompt:
+ dirs = data['dir'] # [B,]
+ # NOTE(review): `is_face` is also read further down (guidance call, colour chamfer).
+ # If this assignment sits under the `use_view_prompt` branch, it is unbound when
+ # that flag is off — confirm against the properly indented file.
+ is_face = data['is_face']
+ if (self.cfg.guidance.normal_text is not None) and len(self.cfg.guidance.normal_text) > 0 and shading == 'normal':
+ if is_face:
+ text_z_novd = self.face_normal_text_z_novd
+ text_z = self.face_normal_text_z[dirs]
+ else:
+ text_z_novd = self.normal_text_z_novd
+ text_z = self.normal_text_z[dirs]
+ elif (self.cfg.guidance.textureless_text is not None) and len(self.cfg.guidance.textureless_text) > 0 and shading == 'textureless':
+ if is_face:
+ text_z_novd = self.face_textureless_text_z_novd
+ text_z = self.face_textureless_text_z[dirs]
+ else:
+ text_z_novd = self.textureless_text_z_novd
+ text_z = self.textureless_text_z[dirs]
+ else:
+ if is_face:
+ text_z_novd = self.face_text_z_novd
+ text_z = self.face_text_z[dirs]
+ else:
+ text_z_novd = self.text_z_novd
+ text_z = self.text_z[dirs]
+ else:
+ text_z = self.text_z
+ text_z_novd = self.text_z
+
+ # Guidance loss; with ControlNet guidance the rendered openpose map is passed as the hint image.
+ if self.cfg.guidance.controlnet_openpose_guidance:
+ controlnet_hint = Image.fromarray((outputs['openpose_map'].detach().cpu().numpy() * 255).astype(np.uint8))
+ loss = loss + self.guidance.train_step(text_z, pred_rgb, guidance_scale=self.cfg.guidance.guidance_scale,
+ controlnet_conditioning_scale=self.cfg.guidance.controlnet_conditioning_scale, controlnet_hint=controlnet_hint,
+ poses=data['poses'], text_embedding_novd=text_z_novd, is_face=is_face)
+ else:
+ loss = loss + self.guidance.train_step(text_z, pred_rgb, guidance_scale=self.cfg.guidance.guidance_scale,
+ poses=data['poses'], text_embedding_novd=text_z_novd, is_face=is_face)
+
+ # Keep the novel-view render: `outputs` is overwritten by the reference-view render below.
+ output_images_novel = outputs['image']
+ output_alpha_novel = outputs['alpha']
+
+ # regularizations
+ # smoothness
+ mesh = outputs['mesh']
+ _mesh = None
+ if i_shading == 0:
+ if flag_train_geometry and self.cfg.train.lambda_lap > 0:
+ loss_lap = laplacian_smooth_loss(mesh.v, mesh.f.long())
+ loss = loss + self.cfg.train.lambda_lap * loss_lap
+
+ # Select the reference target: a normal image (front, or a random front/back pick when a
+ # back normal image exists) for 'normal' shading, otherwise the input image.
+ if (self.normal_image is not None) and shading == 'normal':
+ if self.back_normal_image is not None:
+ recon_image, recon_mask, recon_mask_edt, flip = random.choice([(self.normal_image, self.normal_mask, self.normal_mask_edt, False), (self.back_normal_image, self.back_normal_mask, self.back_normal_mask_edt, True)])
+ else:
+ recon_image = self.normal_image
+ recon_mask = self.normal_mask
+ recon_mask_edt = self.normal_mask_edt
+ flip = False
+ else:
+ recon_image = self.input_image
+ recon_mask = self.input_mask
+ recon_mask_edt = self.input_mask_edt
+ flip = False
+
+ # calculate reconstruction loss
+ # Fixed reference camera: identity with the y and z axes negated.
+ TO_WORLD = np.eye(
+ 4,
+ dtype=np.float32,
+ )
+ TO_WORLD[2,2] = -1
+ TO_WORLD[1,1] = -1
+
+ # From here on H, W, mvp, poses and rays describe the reference view, not the batch view.
+ H, W = recon_image.shape[1:]
+ intrinsics = torch.tensor([H, W, H/2, W/2])[None]
+ mvp = mvp.new_tensor(np.linalg.inv(TO_WORLD)) @ self.model.mesh.resize_matrix_inv
+ poses = torch.tensor(TO_WORLD, dtype=torch.float32)[None]
+ rays = get_rays(poses, intrinsics, H, W, -1)
+ rays_o = rays['rays_o'].to(self.device) # [B, N, 3]
+ rays_d = rays['rays_d'].to(self.device) # [B, N, 3]
+
+ # `flip` is set when the back normal image was chosen: negate the z row of mvp.
+ if flip:
+ flip_mat = torch.eye(4).to(mvp)
+ flip_mat[2, 2] = -1
+ mvp = flip_mat @ mvp
+
+ if shading in ['textureless', 'normal']:
+ outputs = self.model(rays_o, rays_d, mvp, H, W, alpha_only=False, shading=shading,
+ global_step=self.global_step,
+ mesh=step_mesh)
+ pred_alpha = outputs['alpha'].reshape(1, H, W, 1).permute(0, 3, 1, 2).contiguous() # [1, 1, H, W]
+ pred_rgb = outputs['image'].reshape(1, H, W, 3).permute(0, 3, 1, 2).contiguous() # [1, 3, H, W]
+ mask = recon_mask.unsqueeze(0)
+ mask_edt = recon_mask_edt.unsqueeze(0)
+ gt = recon_image.unsqueeze(0)
+ if self.loss_mask_norm is not None and self.normal_image is not None:
+ loss_mask_norm = self.loss_mask_norm if not flip else torch.flip(self.loss_mask_norm, dims=[-1])
+ mask = mask * loss_mask_norm
+ pred_alpha = pred_alpha * loss_mask_norm
+ elif self.loss_mask is not None:
+ mask = mask * self.loss_mask
+ pred_alpha = pred_alpha * self.loss_mask
+
+ if self.cfg.train.lambda_sil > 0.:
+ l_sil = silhouette_loss(pred_alpha.reshape(1, 1, H, W), mask.reshape(1, 1, H, W), edt=mask_edt, l2_weight=1., edge_weight=0.) * self.cfg.train.lambda_sil
+ loss = loss + l_sil
+ if self.normal_image is not None and self.cfg.train.lambda_normal > 0:
+ if self.erosion_normal_mask is not None:
+ mask = mask * (self.erosion_normal_mask if not flip else self.erosion_back_normal_mask)
+ lpips_loss = self.lpips_model(scale_for_lpips(pred_rgb*mask),
+ scale_for_lpips(gt*mask))
+ mse_loss = nn.functional.mse_loss(pred_rgb*mask, gt * mask) *0.2
+ # Weight schedule for the normal loss: cosine decay within each cycle (zero after
+ # `decay_lnorm_cosine_max_iter`), or the last ratio whose step threshold was passed.
+ decay_ratio = 1.
+ if self.cfg.train.decay_lnorm_cosine_cycle is not None:
+ if self.global_step > self.cfg.train.decay_lnorm_cosine_max_iter:
+ decay_ratio = 0.
+ else:
+ t = (self.global_step % self.cfg.train.decay_lnorm_cosine_cycle) / self.cfg.train.decay_lnorm_cosine_cycle
+ decay_ratio = (1 + math.cos(t * math.pi)) / 2
+ else:
+ for step, ratio in zip(self.cfg.train.decay_lnorm_iter, self.cfg.train.decay_lnorm_ratio):
+ if self.global_step > step:
+ decay_ratio = ratio
+ l_norm = (lpips_loss + mse_loss) * self.cfg.train.lambda_normal * decay_ratio
+ loss = loss + l_norm
+ #print('l_sil', l_sil.detach().item(), 'l_norm', l_norm.detach().item())
+ # if self.cfg.train.controlnet_guide_inputview:
+ # controlnet_hint = Image.fromarray(self.controlnet_annotator(self.input_image.permute(1, 2, 0))).resize((512,512))
+ # pred_rgb = outputs['image'].reshape(1, H, W, 3).permute(0, 3, 1, 2).contiguous() # [1, 3, H, W]
+ # loss = loss + self.guidance.train_step(text_z, pred_rgb, guidance_scale=self.cfg.train.guidance_scale, controlnet_conditioning_scale=self.cfg.train.controlnet_conditioning_scale, controlnet_hint=controlnet_hint,
+ # poses=data['poses'], text_embedding_novd=text_z_novd)
+ else:
+ outputs = self.model(rays_o, rays_d, mvp, H, W,
+ global_step=self.global_step,
+ mesh=step_mesh)
+ pred_rgb = outputs['image'].reshape(1, H, W, 3).permute(0, 3, 1, 2).contiguous() # [1, 3, H, W]
+ pred_alpha = outputs['alpha'].reshape(1, H, W, 1).permute(0, 3, 1, 2).contiguous() # [1, 1, H, W]
+ mask = recon_mask.unsqueeze(0)
+ mask_edt = recon_mask_edt.unsqueeze(0)
+ if self.loss_mask is not None:
+ mask = mask * self.loss_mask
+ pred_alpha = pred_alpha * self.loss_mask
+ gt = self.input_image.unsqueeze(0)
+ if self.cfg.train.lambda_sil > 0:
+ loss = loss + silhouette_loss(pred_alpha.reshape(1, 1, H, W), mask.reshape(1, 1, H, W), edt=mask_edt, l2_weight=1., edge_weight=0.) * self.cfg.train.lambda_sil
+
+ if self.erosion_mask is not None:
+ mask = mask * self.erosion_mask
+ mse_loss = nn.functional.mse_loss(pred_rgb*mask, gt * mask) *0.2
+ # Composite prediction and target over the same random background colour before LPIPS.
+ random_color = torch.rand_like(gt[:1, :, :1, :1])
+ pred_rgb = pred_rgb * mask + random_color * (1-mask)
+ gt = gt * mask + random_color * (1-mask)
+ if self.cfg.train.crop_for_lpips: # to save memory
+ pred_rgb, _ = crop_by_mask(pred_rgb, mask)
+ gt, mask = crop_by_mask(gt, mask)
+ lpips_loss = self.lpips_model(scale_for_lpips(pred_rgb),
+ scale_for_lpips(gt))
+ loss = loss + (lpips_loss + mse_loss) * self.cfg.train.lambda_recon
+
+ # Colour chamfer between input-image pixels (input mask > 0.9) and novel-view pixels (alpha > 0.9),
+ # both converted to `cfg.train.color_chamfer_space`.
+ if shading == 'albedo' and (not is_face) and self.cfg.train.lambda_color_chamfer > 0. and self.global_step >= self.cfg.train.color_chamfer_step:
+ h, w = output_images_novel.shape[-3], output_images_novel.shape[-2]
+ input_image = self.input_image.permute(1, 2, 0)
+ input_image = convert_rgb(input_image, self.cfg.train.color_chamfer_space)
+ output_images_novel = output_images_novel.reshape(h, w, 3)
+ output_images_novel = convert_rgb(output_images_novel, self.cfg.train.color_chamfer_space)
+ H, W = input_image.shape[-3], input_image.shape[-2]
+ input_pixels = input_image[self.input_mask.reshape(H, W) > 0.9].unsqueeze(0)
+ pred_pixels = output_images_novel.reshape(h, w, -1)[output_alpha_novel.reshape(h, w) > 0.9].unsqueeze(0)
+ loss = loss + chamfer_distance(input_pixels, pred_pixels, single_directional=self.cfg.train.single_directional_color_chamfer)[0] * self.cfg.train.lambda_color_chamfer
+
+ # NOTE(review): `pred_rgb` is reassigned by the reference-view render above, so the value
+ # returned here is presumably that render rather than the batch view — confirm.
+ return pred_rgb, pred_depth, loss
+
+ def eval_step(self, data, no_resize=True):
+ """Render one validation view and return (pred_rgb, pred_depth, pred_norm, loss).
+
+ The view is rendered twice: once with the selected shading and once with
+ shading='normal' at ambient_ratio 0.1. The returned `loss` is a one-element
+ zero tensor; no validation loss is computed here.
+
+ data: batch dict; reads 'is_face', 'mvp', 'poses', 'intrinsics', 'H', 'W'
+ and optionally 'shading', 'ambient_ratio', 'light_d'.
+ no_resize: when true and the view is not a face view, `mvp` is
+ right-multiplied by `self.model.mesh.resize_matrix_inv`.
+ """
+
+ is_face = data['is_face']
+ mvp = data['mvp']
+ if no_resize and not is_face:
+ mvp = mvp @ self.model.mesh.resize_matrix_inv
+ poses = data['poses']
+ H, W = data['H'], data['W']
+
+ rays = get_rays(data['poses'], data['intrinsics'], H, W, -1)
+ rays_o = rays['rays_o'] # [B, N, 3]
+ rays_d = rays['rays_d'] # [B, N, 3]
+
+ # A sample ratio >= 1 in the training config forces that shading here as well;
+ # otherwise the batch may specify shading/ambient_ratio, defaulting to 'albedo' / 1.0.
+ if self.cfg.train.normal_sample_ratio >= 1.:
+ shading = 'normal'
+ ambient_ratio = 0.1
+ elif self.cfg.train.textureless_sample_ratio >= 1.:
+ shading = 'textureless'
+ ambient_ratio = 0.1
+ else:
+ shading = data['shading'] if 'shading' in data else 'albedo'
+ ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
+ # NOTE(review): `light_d` is used by both model calls below; it must be assigned outside
+ # the `else` branch for the forced-shading cases to work — indentation is lost here, confirm.
+ light_d = data['light_d'] if 'light_d' in data else None
+
+ outputs = self.model(rays_o, rays_d, mvp, H, W, poses=poses, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading,
+ global_step=self.global_step)
+ pred_rgb = outputs['image'].reshape(1, H, W, 3)
+ pred_depth = outputs['depth'].reshape(1, H, W)
+ outputs_normal = self.model(rays_o, rays_d, mvp, H, W, poses=poses, light_d=light_d, ambient_ratio=0.1, shading='normal',
+ global_step=self.global_step)
+ pred_norm = outputs_normal['image'].reshape(1, H, W, 3)
+ # dummy
+ loss = torch.zeros([1], device=pred_rgb.device, dtype=pred_rgb.dtype)
+
+ return pred_rgb, pred_depth, pred_norm, loss
+
+ def test_step(self, data, bg_color=None, perturb=False, mesh=None, can_pose=False, no_resize=False):
+     """Render one test view in the configured shading plus a normal-shaded pass.
+
+     Args:
+         data: batch dict; reads 'is_face', 'mvp', 'poses', 'intrinsics',
+             'H', 'W' and optionally 'shading', 'ambient_ratio', 'light_d'.
+         bg_color, perturb: accepted but not used by this method.
+         mesh: optional mesh forwarded to the model for both render passes.
+         can_pose: forwarded to the model unchanged.
+         no_resize: when true and the view is not a face view, `mvp` is
+             right-multiplied by `self.model.mesh.resize_matrix_inv`.
+
+     Returns:
+         (pred_rgb, pred_depth, pred_norm, pred_alpha, openpose_map, pred_mesh).
+         Images are reshaped to [1, H, W, 3], depth to [1, H, W] and alpha to
+         [1, H, W, 1]. `openpose_map` is None unless `self.render_openpose`
+         is set. `pred_mesh` is the 'mesh' entry of the model output (None if
+         absent), so a caller can pass it back in through `mesh=`.
+     """
+     is_face = data['is_face']
+     mvp = data['mvp']
+     if no_resize and not is_face:
+         mvp = mvp @ self.model.mesh.resize_matrix_inv
+     poses = data['poses']
+     H, W = data['H'], data['W']
+
+     rays = get_rays(data['poses'], data['intrinsics'], H, W, -1)
+     rays_o = rays['rays_o']  # [B, N, 3]
+     rays_d = rays['rays_d']  # [B, N, 3]
+
+     # Bound before the shading chain so it is defined on every branch.
+     light_d = data['light_d'] if 'light_d' in data else None
+
+     # A sample ratio >= 1 in the training config forces that shading at test
+     # time too; otherwise the batch may override the 'albedo' / 1.0 defaults.
+     if self.cfg.train.normal_sample_ratio >= 1:
+         shading = 'normal'
+         ambient_ratio = 0.1
+     elif self.cfg.train.textureless_sample_ratio >= 1:
+         shading = 'textureless'
+         ambient_ratio = 0.1
+     else:
+         shading = data['shading'] if 'shading' in data else 'albedo'
+         ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
+
+     outputs = self.model(rays_o, rays_d, mvp, H, W, poses=poses, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, return_openpose_map=self.render_openpose,
+                          global_step=self.global_step, mesh=mesh, can_pose=can_pose)
+
+     outputs_normal = self.model(rays_o, rays_d, mvp, H, W, poses=poses, light_d=light_d, ambient_ratio=0.1, shading='normal',
+                                 global_step=self.global_step, mesh=mesh, can_pose=can_pose)
+     pred_norm = outputs_normal['image'].reshape(1, H, W, 3)
+
+     pred_rgb = outputs['image'].reshape(1, H, W, 3)
+     pred_depth = outputs['depth'].reshape(1, H, W)
+     pred_alpha = outputs['alpha'].reshape(1, H, W, 1)
+     pred_mesh = outputs.get('mesh', None)
+     if self.render_openpose:
+         openpose_map = outputs['openpose_map'].reshape(1, H, W, 3)
+     else:
+         openpose_map = None
+
+     # Return the model's mesh, not the `mesh` argument: echoing the argument
+     # handed None back on the first call, so a caller could never cache it.
+     return pred_rgb, pred_depth, pred_norm, pred_alpha, openpose_map, pred_mesh
+
+ def save_mesh(self, save_path=None):
+     """Export the model's mesh as `<sub_name>_<stage>` into `save_path`.
+
+     When `save_path` is None the directory `<cfg.exp_root>/obj` is used.
+     The directory is created if needed; UV export follows `cfg.test.save_uv`.
+     """
+     mesh_name = f'{self.cfg.sub_name}_{self.cfg.stage}'
+     out_dir = save_path if save_path is not None else os.path.join(self.cfg.exp_root, 'obj')
+
+     self.log(f"==> Saving mesh to {out_dir}")
+     os.makedirs(out_dir, exist_ok=True)
+     self.model.export_mesh(out_dir, name=mesh_name, export_uv=self.cfg.test.save_uv)
+     self.log(f"==> Finished saving mesh.")
+
+ ### ------------------------------
+
+ def train(self, train_loader, valid_loader, max_epochs):
+ """Run the training loop from the epoch after `self.epoch` up to `max_epochs` inclusive.
+
+ Requires `self.text_z` to be set. On rank 0 with tensorboardX enabled a
+ SummaryWriter is opened under `<workspace>/run/<name>` and closed at the
+ end. Each epoch calls `train_one_epoch`, empties the CUDA cache, saves a
+ full checkpoint on rank 0 when a workspace is set, and evaluates on
+ `valid_loader` whenever the epoch is a multiple of `self.eval_interval`.
+ The total wall-clock time is logged at the end.
+ NOTE(review): indentation is lost in this view, so whether each
+ `save_checkpoint(full=False, best=True)` call is conditional on the
+ preceding evaluation cannot be confirmed from here.
+ """
+
+ assert self.text_z is not None, 'Training must provide a text prompt!'
+
+ if self.use_tensorboardX and self.local_rank == 0:
+ self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name))
+
+ start_t = time.time()
+
+ # Evaluation before any training, at the starting epoch, when it falls on the interval.
+ if self.epoch % self.eval_interval == 0:
+ self.evaluate_one_epoch(valid_loader)
+ self.save_checkpoint(full=False, best=True)
+ for epoch in range(self.epoch + 1, max_epochs + 1):
+ self.epoch = epoch
+
+ self.train_one_epoch(train_loader)
+ torch.cuda.empty_cache()
+ if self.workspace is not None and self.local_rank == 0:
+ self.save_checkpoint(full=True, best=False)
+
+ if self.epoch % self.eval_interval == 0:
+ self.evaluate_one_epoch(valid_loader)
+ self.save_checkpoint(full=False, best=True)
+
+ end_t = time.time()
+
+ self.log(f"[INFO] training takes {(end_t - start_t)/ 60:.4f} minutes.")
+
+ if self.use_tensorboardX and self.local_rank == 0:
+ self.writer.close()
+
+ def evaluate(self, loader, name=None):
+     """Run `evaluate_one_epoch` on `loader` with tensorboard logging switched off.
+
+     Args:
+         loader: data loader handed to `evaluate_one_epoch`.
+         name: optional output name prefix handed to `evaluate_one_epoch`.
+
+     `self.use_tensorboardX` is forced to False for the duration of the call
+     and restored afterwards, including when evaluation raises.
+     """
+     saved_flag = self.use_tensorboardX
+     self.use_tensorboardX = False
+     try:
+         self.evaluate_one_epoch(loader, name)
+     finally:
+         # Without this, an exception during evaluation would leave
+         # tensorboard logging disabled for the rest of the object's life.
+         self.use_tensorboardX = saved_flag
+
+ def test(self, loader, save_path=None, name=None, write_video=True, can_pose=False, write_image=False):
+ """Render every view in `loader` and write the results as videos and/or images.
+
+ save_path: output directory; defaults to `<workspace>/visualize`.
+ name: file-name prefix; defaults to `<last workspace path component>_ep<epoch>`.
+ '_can_pose' is appended when `can_pose` is set.
+ write_video: collect all frames and write mp4 files (rgb only for the
+ 'texture' stage, plus full / depth / norm, and openpose when
+ `self.render_openpose` is set).
+ can_pose: forwarded to `test_step`; `no_resize` is passed as its negation.
+ write_image: additionally write RGBA png files for every 10th frame.
+
+ RGB and normal frames are composited over white using the predicted alpha;
+ depth is min-max normalised per frame.
+ """
+
+ if save_path is None:
+ save_path = os.path.join(self.workspace, 'visualize')
+
+ if name is None:
+ name = f'{self.workspace.split("/")[-1]}_ep{self.epoch:04d}'
+ if can_pose:
+ name = name + '_can_pose'
+
+ os.makedirs(save_path, exist_ok=True)
+
+ self.log(f"==> Start Test, save results to {save_path}")
+
+ pbar = tqdm.tqdm(
+ total=len(loader) * loader.batch_size,
+ bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
+ self.model.eval()
+
+ if write_video:
+ all_preds = []
+ all_preds_depth = []
+ all_preds_norm = []
+ all_openpose_map = []
+
+ with torch.no_grad():
+ # Mesh from the first frame, handed back to `test_step` for the following frames.
+ mesh = None
+ for i, data in enumerate(loader):
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ preds, preds_depth, preds_norm, preds_alpha, openpose_map, pred_mesh = self.test_step(data, mesh=mesh, can_pose=can_pose, no_resize=not can_pose)
+ if mesh is None:
+ mesh = pred_mesh
+ preds_alpha = preds_alpha[0].detach().cpu().numpy()
+
+ pred = preds[0].detach().cpu().numpy()
+ #pred = (pred * 255).astype(np.uint8)
+ # Composite over a white background: colour * alpha + (1 - alpha).
+ pred = ((pred * preds_alpha + (1-preds_alpha))* 255).astype(np.uint8)
+
+ pred_norm = preds_norm[0].detach().cpu().numpy()
+ pred_norm = ((pred_norm * preds_alpha + (1-preds_alpha)) * 255).astype(np.uint8)
+
+ pred_depth = preds_depth[0].detach().cpu().numpy()
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() - pred_depth.min() + 1e-6)
+ pred_depth = (pred_depth * 255).astype(np.uint8)
+ if self.render_openpose:
+
+ openpose_map = (openpose_map[0].detach().cpu().numpy() * 255).astype(np.uint8)
+
+ if write_video:
+ all_preds.append(pred)
+ all_preds_depth.append(pred_depth)
+ all_preds_norm.append(pred_norm)
+ if self.render_openpose:
+ all_openpose_map.append(openpose_map)
+ if write_image and i % 10 == 0:
+ # NOTE(review): `preds_alpha` appears to be a NumPy array already at this point,
+ # which would make this tensor check a no-op — confirm.
+ if isinstance(preds_alpha, torch.Tensor):
+ preds_alpha = preds_alpha[0].detach().cpu().numpy()
+ preds_alpha = (preds_alpha * 255).astype(np.uint8)
+ # Append alpha as a fourth channel so the png files are RGBA.
+ pred = np.concatenate([pred, preds_alpha], axis=-1)
+ pred_norm = np.concatenate([pred_norm, preds_alpha], axis=-1)
+ cv2.imwrite(
+ os.path.join(save_path, f'{name}_{i:04d}_rgb.png'), cv2.cvtColor(pred, cv2.COLOR_RGBA2BGRA))
+ cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_depth.png'), pred_depth)
+ cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_norm.png'), cv2.cvtColor(pred_norm, cv2.COLOR_RGBA2BGRA))
+
+ pbar.update(loader.batch_size)
+
+ if write_video:
+ all_preds = np.stack(all_preds, axis=0)
+ all_preds_depth = np.stack(all_preds_depth, axis=0)
+ all_preds_norm = np.stack(all_preds_norm, axis=0)
+ # NOTE(review): frames [:100] and [100:] are concatenated side by side along the width
+ # axis, which needs both slices to hold the same number of frames; this presumably
+ # assumes the test loader yields exactly 200 views — confirm.
+ all_preds_full = np.concatenate(
+ [
+ np.concatenate([all_preds[:100], all_preds_norm[:100]], axis=2),
+ np.concatenate([all_preds[100:], all_preds_norm[100:]], axis=2),
+ ], axis=2
+ )
+ if self.cfg.stage == 'texture':
+ imageio.mimwrite(
+ os.path.join(save_path, f'{name}_rgb.mp4'), all_preds[:100], fps=25, quality=8, macro_block_size=1)
+ imageio.mimwrite(
+ os.path.join(save_path, f'{name}_full.mp4'), all_preds_full, fps=25, quality=8, macro_block_size=1)
+ imageio.mimwrite(
+ os.path.join(save_path, f'{name}_depth.mp4'), all_preds_depth[:100], fps=25, quality=8, macro_block_size=1)
+ imageio.mimwrite(
+ os.path.join(save_path, f'{name}_norm.mp4'), all_preds_norm[:100], fps=25, quality=8, macro_block_size=1)
+ if self.render_openpose:
+ all_openpose_map = np.stack(all_openpose_map, axis=0)
+ imageio.mimwrite(
+ os.path.join(save_path, f'{name}_openpose.mp4'), all_openpose_map, fps=25, quality=8, macro_block_size=1)
+
+ self.log(f"==> Finished Test.")
+
+
+ def train_one_epoch(self, loader):
+ """Train the model for one pass over `loader`.
+
+ For every batch: increments `local_step` / `global_step`, zeroes the
+ gradients, runs `train_step` under autocast (enabled by `self.fp16`),
+ back-propagates through the gradient scaler and steps the optimizer.
+ The learning-rate scheduler is stepped per batch when
+ `self.scheduler_update_every_step` is set, otherwise once at the end of
+ the epoch (with the average loss for ReduceLROnPlateau). The average loss
+ is appended to `self.stats["loss"]`.
+ """
+ self.log(
+ f"==> Start Training {self.workspace} Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ..."
+ )
+
+ total_loss = 0
+ if self.local_rank == 0 and self.report_metric_at_train:
+ for metric in self.metrics:
+ metric.clear()
+
+ self.model.train()
+
+ # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs
+ # ref: https://pytorch.org/docs/stable/data.html
+ if self.world_size > 1:
+ loader.sampler.set_epoch(self.epoch)
+
+ # The progress bar exists on rank 0 only; every later use of `pbar` is guarded the same way.
+ if self.local_rank == 0:
+ pbar = tqdm.tqdm(
+ total=len(loader) * loader.batch_size,
+ bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
+
+ self.local_step = 0
+
+ for data in loader:
+
+ self.local_step += 1
+ self.global_step += 1
+
+ self.optimizer.zero_grad()
+
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ pred_rgbs, pred_depths, loss = self.train_step(data)
+
+ # Scaled backward pass, optimizer step through the scaler, then scale-factor update.
+ self.scaler.scale(loss).backward()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+
+ if self.scheduler_update_every_step:
+ self.lr_scheduler.step()
+
+ loss_val = loss.item()
+ total_loss += loss_val
+
+ if self.local_rank == 0:
+ if self.use_tensorboardX:
+ self.writer.add_scalar("train/loss", loss_val, self.global_step)
+ self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step)
+
+ if self.scheduler_update_every_step:
+ pbar.set_description(
+ f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}"
+ )
+ else:
+ pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
+ pbar.update(loader.batch_size)
+
+ # NOTE(review): indentation is lost in this view; whether the EMA update runs once per
+ # batch or once per epoch cannot be confirmed from here.
+ if self.ema is not None:
+ self.ema.update()
+
+ average_loss = total_loss / self.local_step
+ self.stats["loss"].append(average_loss)
+
+ if self.local_rank == 0:
+ pbar.close()
+ if self.report_metric_at_train:
+ for metric in self.metrics:
+ self.log(metric.report(), style="red")
+ if self.use_tensorboardX:
+ metric.write(self.writer, self.epoch, prefix="train")
+ metric.clear()
+
+ if not self.scheduler_update_every_step:
+ if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+ self.lr_scheduler.step(average_loss)
+ else:
+ self.lr_scheduler.step()
+
+ self.log(f"==> Finished Epoch {self.epoch}.")
+
+ def evaluate_one_epoch(self, loader, name=None):
+ """Render every batch of `loader` with `eval_step` and save the images for inspection.
+
+ name: file-name prefix for the saved images; defaults to `<self.name>_ep<epoch>`.
+
+ The model is put in eval mode and, when an EMA is configured, its weights
+ are swapped in for the duration of the pass and restored afterwards. With
+ more than one process the loss is averaged and the predictions gathered
+ across ranks. Rank 0 writes rgb / depth / normal png files under
+ `<workspace>/validation`. The average loss is appended to
+ `self.stats["valid_loss"]`, and rank 0 appends either the first metric's
+ measure or the average loss to `self.stats["results"]`.
+ """
+ self.log(f"++> Evaluate {self.workspace} at epoch {self.epoch} ...")
+
+ if name is None:
+ name = f'{self.name}_ep{self.epoch:04d}'
+
+ total_loss = 0
+ if self.local_rank == 0:
+ for metric in self.metrics:
+ metric.clear()
+
+ self.model.eval()
+
+ # Evaluate with EMA weights; the stored originals are restored at the end.
+ if self.ema is not None:
+ self.ema.store()
+ self.ema.copy_to()
+
+ if self.local_rank == 0:
+ pbar = tqdm.tqdm(
+ total=len(loader) * loader.batch_size,
+ bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
+
+ with torch.no_grad():
+ self.local_step = 0
+
+ for data in loader:
+ self.local_step += 1
+
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ preds, preds_depth, preds_normal, loss = self.eval_step(data)
+
+ # all_gather/reduce the statistics (NCCL only support all_*)
+ if self.world_size > 1:
+ dist.all_reduce(loss, op=dist.ReduceOp.SUM)
+ loss = loss / self.world_size
+
+ preds_list = [torch.zeros_like(preds).to(self.device) for _ in range(self.world_size)
+ ] # [[B, ...], [B, ...], ...]
+ dist.all_gather(preds_list, preds)
+ preds = torch.cat(preds_list, dim=0)
+
+ preds_depth_list = [torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size)
+ ] # [[B, ...], [B, ...], ...]
+ dist.all_gather(preds_depth_list, preds_depth)
+ preds_depth = torch.cat(preds_depth_list, dim=0)
+
+ preds_normal_list = [torch.zeros_like(preds_normal).to(self.device) for _ in range(self.world_size)
+ ] # [[B, ...], [B, ...], ...]
+ dist.all_gather(preds_normal_list, preds_normal)
+ preds_normal = torch.cat(preds_normal_list, dim=0)
+ loss_val = loss.item()
+ total_loss += loss_val
+
+ # only rank = 0 will perform evaluation.
+ if self.local_rank == 0:
+
+ # save image
+ save_path = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_rgb.png')
+ save_path_normal = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_normal.png')
+ save_path_depth = os.path.join(self.workspace, 'validation',
+ f'{name}_{self.local_step:04d}_depth.png')
+
+ #self.log(f"==> Saving validation image to {save_path}")
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+ # Only the first element of each prediction batch is written to disk.
+ pred = preds[0].detach().cpu().numpy()
+ pred = (pred * 255).astype(np.uint8)
+
+ pred_depth = preds_depth[0].detach().cpu().numpy()
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() - pred_depth.min() + 1e-6)
+ pred_depth = (pred_depth * 255).astype(np.uint8)
+
+ pred_normal = preds_normal[0].detach().cpu().numpy()
+ pred_normal = (pred_normal * 255).astype(np.uint8)
+
+ cv2.imwrite(save_path, cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
+ cv2.imwrite(save_path_depth, pred_depth)
+ cv2.imwrite(save_path_normal, cv2.cvtColor(pred_normal, cv2.COLOR_RGB2BGR))
+
+ pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
+ pbar.update(loader.batch_size)
+
+ average_loss = total_loss / self.local_step
+ self.stats["valid_loss"].append(average_loss)
+
+ if self.local_rank == 0:
+ pbar.close()
+ if not self.use_loss_as_metric and len(self.metrics) > 0:
+ result = self.metrics[0].measure()
+ self.stats["results"].append(result if self.best_mode == 'min' else -result) # if max mode, use -result
+ else:
+ self.stats["results"].append(average_loss) # if no metric, choose best by min loss
+
+ for metric in self.metrics:
+ self.log(metric.report(), style="blue")
+ if self.use_tensorboardX:
+ metric.write(self.writer, self.epoch, prefix="evaluate")
+ metric.clear()
+
+ if self.ema is not None:
+ self.ema.restore()
+
+ self.log(f"++> Evaluate epoch {self.epoch} Finished.")
+
+ def save_checkpoint(self, name=None, full=False, best=False):
+ """Write a checkpoint: either a rolling per-epoch file or the single "best" file.
+
+ name: checkpoint base name; defaults to `<self.name>_ep<epoch>`.
+ full: also store optimizer, lr-scheduler and grad-scaler state.
+ best: when false, save `<name>.pth` under `self.ckpt_path`, record it in
+ `self.stats["checkpoints"]` and delete the oldest recorded file once
+ more than `self.max_keep_ckpt` are tracked. When true, save to
+ `self.best_path` with the EMA weights (if any) temporarily copied into
+ the model; this is skipped with a warning if `self.stats["results"]`
+ is empty.
+
+ The saved dict always contains 'epoch', 'global_step' and 'stats'.
+ NOTE(review): indentation is lost in this view, so whether the 'ema'
+ entry is written only for `full=True` cannot be confirmed from here.
+ """
+
+ if name is None:
+ name = f'{self.name}_ep{self.epoch:04d}'
+
+ state = {
+ 'epoch': self.epoch,
+ 'global_step': self.global_step,
+ 'stats': self.stats,
+ }
+
+ if full:
+ state['optimizer'] = self.optimizer.state_dict()
+ state['lr_scheduler'] = self.lr_scheduler.state_dict()
+ state['scaler'] = self.scaler.state_dict()
+ if self.ema is not None:
+ state['ema'] = self.ema.state_dict()
+
+ if not best:
+
+ state['model'] = self.model.state_dict()
+
+ file_path = f"{name}.pth"
+
+ self.stats["checkpoints"].append(file_path)
+
+ # Rolling window: drop the oldest tracked checkpoint from the list and from disk.
+ if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
+ old_ckpt = os.path.join(self.ckpt_path, self.stats["checkpoints"].pop(0))
+ if os.path.exists(old_ckpt):
+ os.remove(old_ckpt)
+
+ torch.save(state, os.path.join(self.ckpt_path, file_path))
+
+ else:
+ if len(self.stats["results"]) > 0:
+ # always save best since loss cannot reflect performance.
+ if True:
+ # self.log(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}")
+ # self.stats["best_result"] = self.stats["results"][-1]
+
+ # save ema results
+ if self.ema is not None:
+ self.ema.store()
+ self.ema.copy_to()
+
+ state['model'] = self.model.state_dict()
+
+ if self.ema is not None:
+ self.ema.restore()
+
+ torch.save(state, self.best_path)
+ else:
+ self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.")
+
+ def load_pretrained(self, pretrained=None):
+     """Load model weights from a pretrained file, ignoring per-subject geometry entries.
+
+     Args:
+         pretrained: path to a file readable by `torch.load`; None is a no-op.
+
+     If the loaded object has a 'model' key, that entry is used as the state
+     dict. The 'v_offsets', 'vn_offsets' and 'sdf' entries are removed before
+     the non-strict load; missing and unexpected keys are logged as warnings.
+     """
+     if pretrained is None:
+         return
+
+     self.log("[INFO] loading pretrained model from {}".format(pretrained))
+     weights = torch.load(pretrained, map_location=self.device)
+     if 'model' in weights:
+         weights = weights['model']
+
+     # Removed before loading, so these entries are never copied from the file.
+     for dropped_key in ('v_offsets', 'vn_offsets', 'sdf'):
+         weights.pop(dropped_key, None)
+
+     missing, unexpected = self.model.load_state_dict(weights, strict=False)
+     self.log("[INFO] loaded model.")
+     if len(missing) > 0:
+         self.log(f"[WARN] missing keys: {missing}")
+     if len(unexpected) > 0:
+         self.log(f"[WARN] unexpected keys: {unexpected}")
+
+ def load_checkpoint(self, checkpoint=None, model_only=False):
+     """Restore model (and optionally training) state from a checkpoint file.
+
+     Args:
+         checkpoint: path to a `.pth` file. When None, the last entry of the
+             sorted `*.pth` files under `self.ckpt_path` is used; if there is
+             none, a warning is logged and nothing is loaded.
+         model_only: when true, stop after the model weights (and EMA, if
+             present) are loaded and leave epoch / step / optimizer state alone.
+
+     A checkpoint without a 'model' key is treated as a bare state dict and
+     loaded strictly. Otherwise the 'model' entry is loaded with
+     `strict=False` and missing / unexpected keys are logged. Errors raised
+     while restoring EMA, optimizer, scheduler or scaler state are logged as
+     warnings and do not abort loading; KeyboardInterrupt and SystemExit are
+     not intercepted.
+     """
+     if checkpoint is None:
+         checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/*.pth'))
+         if checkpoint_list:
+             checkpoint = checkpoint_list[-1]
+             self.log(f"[INFO] Latest checkpoint is {checkpoint}")
+         else:
+             self.log("[WARN] No checkpoint found, model randomly initialized.")
+             return
+
+     checkpoint_dict = torch.load(checkpoint, map_location=self.device)
+
+     if 'model' not in checkpoint_dict:
+         self.model.load_state_dict(checkpoint_dict)
+         self.log("[INFO] loaded model.")
+         return
+
+     missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
+     self.log("[INFO] loaded model.")
+     if len(missing_keys) > 0:
+         self.log(f"[WARN] missing keys: {missing_keys}")
+     if len(unexpected_keys) > 0:
+         self.log(f"[WARN] unexpected keys: {unexpected_keys}")
+
+     if self.ema is not None and 'ema' in checkpoint_dict:
+         try:
+             self.ema.load_state_dict(checkpoint_dict['ema'])
+             self.log("[INFO] loaded EMA.")
+         except Exception:
+             # Best effort: an incompatible EMA state must not block the load.
+             self.log("[WARN] failed to loaded EMA.")
+
+     if model_only:
+         return
+
+     self.stats = checkpoint_dict['stats']
+     self.epoch = checkpoint_dict['epoch']
+     self.global_step = checkpoint_dict['global_step']
+     self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
+
+     # (attribute / checkpoint key, label used in the log messages)
+     for key, label in (('optimizer', 'optimizer'),
+                        ('lr_scheduler', 'scheduler'),
+                        ('scaler', 'scaler')):
+         target = getattr(self, key)
+         if target and key in checkpoint_dict:
+             try:
+                 target.load_state_dict(checkpoint_dict[key])
+                 self.log(f"[INFO] loaded {label}.")
+             except Exception:
+                 # Best effort, as above; only ordinary errors are downgraded.
+                 self.log(f"[WARN] Failed to load {label}.")
\ No newline at end of file
diff --git a/core/lib/uv_utils.py b/core/lib/uv_utils.py
new file mode 100755
index 0000000..15a5a93
--- /dev/null
+++ b/core/lib/uv_utils.py
@@ -0,0 +1,107 @@
+import numpy as np
+import cv2
+def diffuse_color_with_mask(img_m, img_c, num_iter=1, ksize=3):
+ """
+ Spread the colours of masked pixels outward into neighbouring unmasked pixels.
+
+ img_m: mask image; any non-zero value counts as filled. It is binarised to
+ 0/255 in place (the caller's array is modified) before being padded.
+ img_c: colour image, padded alongside the mask with a three-component zero border.
+ num_iter: number of passes; each pass fills the unmasked pixels that have a
+ filled pixel in their 3x3 neighbourhood.
+ ksize: side of the square window whose filled pixels are averaged to colour
+ a newly filled pixel.
+ Returns (mask, colour) with the padding cropped off; colour is cast to uint8.
+ NOTE(review): ksize=1 gives hksize=0, for which the final `[hksize:-hksize]`
+ crop is empty — presumably ksize >= 3 is expected; confirm.
+
+ cv.findContours: http://t.zoukankan.com/wojianxin-p-12602490.html
+ """
+ img_m[img_m != 0] = 255
+
+ hksize = ksize // 2
+ k_range = range(-hksize, hksize + 1)
+
+ #* expand
+ img_m = cv2.copyMakeBorder(img_m, hksize, hksize, hksize, hksize, cv2.BORDER_CONSTANT, value=(0))
+ img_c = cv2.copyMakeBorder(img_c, hksize, hksize, hksize, hksize, cv2.BORDER_CONSTANT, value=(0, 0, 0))
+
+ for _ in range(num_iter):
+ uu, vv = np.where(img_m == 0)
+
+ #* remove border
+ # `m` starts as a scalar; the first `&=` with an array turns it into a boolean array.
+ m = True
+ m &= (uu >= hksize)
+ m &= (uu < img_m.shape[0] - hksize)
+ m &= (vv >= hksize)
+ m &= (vv < img_m.shape[1] - hksize)
+ uu = uu[m]
+ vv = vv[m]
+
+ #* select silhouette. only 3x3 patch
+ m = False
+ for tu in [-1, 0, 1]:
+ for tv in [-1, 0, 1]:
+ m |= (img_m[uu + tu, vv + tv] == 255)
+
+ uu = uu[m]
+ vv = vv[m]
+ # 127 marks pixels filled in this pass so they are not counted as colour sources below.
+ img_m[uu, vv] = 127 #! set silhouette value
+
+ #* calc weights: 0/1 | sum and mean
+ c = 0
+ w = 0
+ for tu in k_range:
+ for tv in k_range:
+ tw = (img_m[uu + tu, vv + tv] == 255).astype(np.float32).reshape(-1, 1)
+ tc = (img_c[uu + tu, vv + tv]).astype(np.float32)
+ w += tw
+ c += tw * tc
+ # Mean colour of the filled pixels inside the window.
+ img_c[uu, vv] = (c / w).astype(np.float32)
+ # NOTE(review): indentation is lost in this view; this reset presumably runs at the end
+ # of every pass so the next pass can grow from the newly filled ring — confirm.
+ img_m[img_m == 127] = 255 #!
+
+ img_m = img_m[hksize:-hksize, hksize:-hksize]
+ img_c = img_c[hksize:-hksize, hksize:-hksize].astype(np.uint8)
+
+ return img_m, img_c
+
+
+def texture_padding(img_c0, img_m0, fac=1.25):
+ """
+ Fill the unmasked part of a texture image with colours extended from the masked part.
+
+ img_c0: colour image.
+ img_m0: mask image; non-zero means valid. It is binarised to 0/255 in place,
+ so the caller's array is modified.
+ fac: resize factor per pyramid step; must satisfy 1 < fac < 1.5.
+ Returns the colour image; it is returned unchanged when every mask value is positive.
+
+ Outline: diffuse colours by two pixels, shrink mask and colours by `fac`
+ (diffusing again each time) until the mask has no zeros left, enlarge back
+ to the original size, then copy the enlarged colours into pixels that were
+ empty in the diffused original mask.
+ NOTE(review): indentation is lost in this view, so exactly which statements
+ belong to the two `while` loops cannot be confirmed from here.
+
+ * question: https://blender.stackexchange.com/a/265246/82691
+ Here are some related keywords/links:
+ [Texture Padding](https://www.youtube.com/watch?v=MVsIIkJNkjM&ab_channel=malcolm341),
+ `Solidify` in [Free Plug-ins](http://www.flamingpear.com/free-trials.html)
+ and [Seam Fixing](https://www.youtube.com/watch?v=r9l8RfTvqyI&ab_channel=NamiNaeko);
+ [TexTools](https://github.com/SavMartin/TexTools-Blender) for Blender.
+ * reference:
+ [inpainting for atlas/texture map](https://blender.stackexchange.com/questions/264966/inpainting-for-atlas-texture-map)
+ [mipmap](https://substance3d.adobe.com/documentation/spdoc/padding-134643719.html)
+ [distance transform](https://stackoverflow.com/questions/26421566/pixel-indexing-in-opencvs-distance-transform)
+ [seamlessClone](https://learnopencv.com/seamless-cloning-using-opencv-python-cpp/)
+ [torch-interpol](https://github.com/balbasty/torch-interpol/issues/1)
+ """
+
+ assert 1 < fac < 1.5
+
+ if np.all(img_m0 > 0):
+ return img_c0
+
+ img_m0[img_m0 != 0] = 255
+
+ img_m0, img_c0 = diffuse_color_with_mask(img_m0, img_c0, 2) #* diffuse 2 pixels (2x2 downsampling)
+
+ img_m1 = img_m0.copy()
+ img_c1 = img_c0.copy()
+ while np.any(img_m1 == 0):
+ # NOTE(review): cv2.resize takes dsize as (width, height) while shape[0] / shape[1] are
+ # (rows, cols); the two only agree for square images — confirm textures are square.
+ img_m1 = cv2.resize(img_m1, (int(img_m1.shape[0] / fac), int(img_m1.shape[1] / fac)), interpolation=cv2.INTER_LINEAR)
+ img_c1 = cv2.resize(img_c1, (int(img_c1.shape[0] / fac), int(img_c1.shape[1] / fac)), interpolation=cv2.INTER_LINEAR)
+ # Interpolated mask values that are not fully 255 are treated as empty again.
+ img_m1[img_m1 != 255] = 0
+ img_c1[img_m1 == 0] = 0
+ img_m1, img_c1 = diffuse_color_with_mask(img_m1, img_c1, 2)
+
+ img_m2 = img_m1.copy()
+ img_c2 = img_c1.copy()
+ while img_m2.shape[0] != img_m0.shape[0]:
+ # Snap to the exact original size when one more step of `fac` would overshoot it.
+ if (img_m0.shape[0] < img_m2.shape[0] * fac < img_m0.shape[0] * fac):
+ img_shape = (img_m0.shape[0], img_m0.shape[1])
+ else:
+ img_shape = (int(img_m2.shape[0] * fac), int(img_m2.shape[1] * fac))
+ img_m2 = cv2.resize(img_m2, img_shape, interpolation=cv2.INTER_LINEAR)
+ img_c2 = cv2.resize(img_c2, img_shape, interpolation=cv2.INTER_LINEAR)
+
+ img_m2[img_m2 != 255] = 0
+ img_c2[img_m2 == 0] = 0
+
+ # Bitwise select of pixels empty in img_m0 but filled in img_m2.
+ # assumes both masks are uint8 holding only 0 / 255 — TODO confirm.
+ nnz = np.nonzero(~img_m0 & img_m2)
+ img_c0[nnz] = img_c2[nnz]
+ img_m0 = img_m2
+
+ return img_c0
\ No newline at end of file
diff --git a/core/main.py b/core/main.py
new file mode 100755
index 0000000..328cbad
--- /dev/null
+++ b/core/main.py
@@ -0,0 +1,148 @@
+#import nvdiffrast.torch as dr
+import torch
+import argparse
+
+from lib.provider import ViewDataset
+from lib.trainer import *
+from lib.renderer import Renderer
+
+from yacs.config import CfgNode as CN
+
+
+def load_config(path, default_path=None):
+    """Build a yacs config by merging ``path`` on top of an optional default file.
+
+    Args:
+        path: config file merged last, so its values take precedence.
+        default_path: optional config file merged first; skipped when ``None``.
+
+    Returns:
+        The merged ``CfgNode`` (created with ``new_allowed=True``).
+    """
+    sources = [path] if default_path is None else [default_path, path]
+    merged = CN(new_allowed=True)
+    for source in sources:
+        merged.merge_from_file(source)
+    return merged
+#torch.autograd.set_detect_anomaly(True)
+
+# Script entry point: parse the CLI, derive per-subject file paths under --exp_dir,
+# then either run the test-only branch (--test) or train and run a final test pass.
+# NOTE(review): os, np, optim, seed_everything and Trainer are not imported by name in
+# this file; presumably they arrive through `from lib.trainer import *` - verify.
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, required=True, help="config file")
+ parser.add_argument('--exp_dir', type=str, required=True, help="experiment dir")
+ parser.add_argument('--sub_name', type=str, required=True, help="subject name")
+ parser.add_argument('--seed', type=int, default=42, help="random seed")
+ parser.add_argument('--test', action="store_true")
+
+
+ opt = parser.parse_args()
+ # The given config is merged on top of configs/default.yaml (path relative to the cwd).
+ cfg = load_config(opt.config, default_path="configs/default.yaml")
+ cfg.test.test = opt.test
+ cfg.workspace = os.path.join(opt.exp_dir, cfg.stage)
+ cfg.exp_root = opt.exp_dir
+ cfg.sub_name = opt.sub_name
+ # Each cfg.data.load_* flag fills in the matching input path under exp_dir,
+ # built from the subject name.
+ if cfg.data.load_input_image:
+ cfg.data.img = os.path.join(opt.exp_dir, 'png', "{}_crop.png".format(opt.sub_name))
+ if cfg.data.load_front_normal:
+ cfg.data.front_normal_img = os.path.join(opt.exp_dir, 'normal', "{}_normal_front.png".format(opt.sub_name))
+ if cfg.data.load_back_normal:
+ cfg.data.back_normal_img = os.path.join(opt.exp_dir, 'normal', "{}_normal_back.png".format(opt.sub_name))
+ if cfg.data.load_keypoints:
+ cfg.data.keypoints_path = os.path.join(opt.exp_dir, 'obj', "{}_smpl.npy".format(opt.sub_name))
+ if cfg.data.load_result_mesh:
+ cfg.data.last_model = os.path.join(opt.exp_dir, 'obj', "{}_pose.obj".format(opt.sub_name))
+ cfg.data.last_ref_model = os.path.join(opt.exp_dir, 'obj', "{}_smpl.obj".format(opt.sub_name))
+ else:
+ cfg.data.last_model = os.path.join(opt.exp_dir, 'obj', "{}_smpl.obj".format(opt.sub_name))
+ if cfg.data.load_apose_mesh:
+ cfg.data.can_pose_folder = os.path.join(opt.exp_dir, 'obj', "{}_apose.obj".format(opt.sub_name))
+ # NOTE(review): this check and assignment repeat the two lines directly above.
+ if cfg.data.load_apose_mesh:
+ cfg.data.can_pose_folder = os.path.join(opt.exp_dir, 'obj', "{}_apose.obj".format(opt.sub_name))
+ if cfg.data.load_occ_mask:
+ cfg.data.occ_mask = os.path.join(opt.exp_dir, 'png', "{}_occ_mask.png".format(opt.sub_name))
+ if cfg.data.load_da_pose_mesh:
+ cfg.data.da_pose_mesh = os.path.join(opt.exp_dir, 'obj', "{}_da_pose.obj".format(opt.sub_name))
+ if cfg.guidance.use_dreambooth:
+ cfg.guidance.hf_key = os.path.join(opt.exp_dir, 'sd_model')
+ # Without a configured prompt, use the text before the first '|' on the first line of prompt.txt.
+ if cfg.guidance.text is None:
+ with open(os.path.join(opt.exp_dir, 'prompt.txt'), 'r') as f:
+ cfg.guidance.text = f.readlines()[0].split('|')[0]
+
+ print(cfg)
+
+ seed_everything(opt.seed)
+ model = Renderer(cfg)
+ # Head position comes from the model keypoints when present (entry 0 if there is
+ # a single keypoint, entry 15 otherwise); else a fixed default is used.
+ if model.keypoints is not None:
+ if len(model.keypoints[0]) == 1:
+ cfg.train.head_position = model.keypoints[0][0].cpu().numpy().tolist()
+ else:
+ cfg.train.head_position = model.keypoints[0][15].cpu().numpy().tolist()
+ else:
+ cfg.train.head_position = np.array([0., 0.4, 0.], dtype=np.float32).tolist()
+ cfg.train.canpose_head_position = np.array([0., 0.4, 0.], dtype=np.float32).tolist()
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ if cfg.test.test:
+ guidance = None # no need to load guidance model at test
+ # NOTE(review): this branch reads cfg.fp16 while the training branch reads cfg.train.fp16 - confirm both keys exist.
+ trainer = Trainer(
+ 'df', cfg, model, guidance, device=device, workspace=cfg.workspace, fp16=cfg.fp16, use_checkpoint=cfg.train.ckpt, pretrained=cfg.train.pretrained)
+
+ if not cfg.test.not_test_video:
+ test_loader = ViewDataset(cfg, device=device, type='test', H=cfg.test.H, W=cfg.test.W, size=100, render_head=True).dataloader()
+ trainer.test(test_loader, write_image=cfg.test.write_image)
+ if cfg.data.can_pose_folder is not None:
+ trainer.test(test_loader, write_image=cfg.test.write_image, can_pose=True)
+ if cfg.test.save_mesh:
+ trainer.save_mesh()
+ else:
+
+ train_loader = ViewDataset(cfg, device=device, type='train', H=cfg.train.h, W=cfg.train.w, size=100).dataloader()
+ params_list = list()
+ if cfg.guidance.type == 'stable-diffusion':
+ from lib.guidance import StableDiffusion
+ guidance = StableDiffusion(device, cfg.guidance.sd_version, cfg.guidance.hf_key, cfg.guidance.step_range, controlnet=cfg.guidance.controlnet, lora=cfg.guidance.lora, cfg=cfg, head_hf_key=cfg.guidance.head_hf_key)
+ # The guidance model is frozen; only the parameters in params_list are optimized.
+ for p in guidance.parameters():
+ p.requires_grad = False
+ else:
+ raise NotImplementedError(f'--guidance {cfg.guidance.type} is not implemented.')
+
+ # The optimizer factories ignore their `model` argument and close over params_list.
+ if cfg.train.optim == 'adan':
+ from lib.optimizer import Adan
+ # Adan usually requires a larger LR
+ params_list.extend(model.get_params(5 * cfg.train.lr))
+ optimizer = lambda model: Adan(
+ params_list, eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False)
+ else: # adam
+ params_list.extend(model.get_params(cfg.train.lr))
+ optimizer = lambda model: torch.optim.Adam(params_list, betas=(0.9, 0.99), eps=1e-15)
+
+ # scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 1) # fixed
+ # LR multiplier is 0.1**(iter / cfg.train.iters), i.e. 1 at step 0, held at 0.1 once iter reaches cfg.train.iters.
+ scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1**min(iter / cfg.train.iters, 1))
+
+ trainer = Trainer(
+ 'df',
+ cfg,
+ model,
+ guidance,
+ device=device,
+ workspace=cfg.workspace,
+ optimizer=optimizer,
+ ema_decay=None,
+ fp16=cfg.train.fp16,
+ lr_scheduler=scheduler,
+ use_checkpoint=cfg.train.ckpt,
+ eval_interval=cfg.train.eval_interval,
+ scheduler_update_every_step=True,
+ pretrained=cfg.train.pretrained)
+
+ valid_loader = ViewDataset(cfg, device=device, type='val', H=cfg.test.H, W=cfg.test.W, size=5).dataloader()
+
+ # Number of epochs needed to cover cfg.train.iters steps, rounded up.
+ max_epoch = np.ceil(cfg.train.iters / len(train_loader)).astype(np.int32)
+ if cfg.profile:
+ import cProfile
+ with cProfile.Profile() as pr:
+ trainer.train(train_loader, valid_loader, max_epoch)
+ pr.dump_stats(os.path.join(cfg.workspace, 'profile.dmp'))
+ pr.print_stats()
+ else:
+ trainer.train(train_loader, valid_loader, max_epoch)
+
+ # After training, run one test pass and optionally export the mesh.
+ test_loader = ViewDataset(cfg, device=device, type='test', H=cfg.test.H, W=cfg.test.W, size=100, render_head=True).dataloader()
+ trainer.test(test_loader, write_image=cfg.test.write_image)
+
+ if cfg.test.save_mesh:
+ trainer.save_mesh()
\ No newline at end of file
diff --git a/docs/install.md b/docs/install.md
new file mode 100644
index 0000000..4d20e55
--- /dev/null
+++ b/docs/install.md
@@ -0,0 +1,32 @@
+## Environment setup
+
+1. We have tested our code with this docker environment `pytorch/pytorch:1.13.0-cuda11.6-cudnn8-devel` and NVIDIA V100 GPUs.
+2. Install PyTorch: `pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0`
+3. Install other dependencies:
+```sh
+# install libraries
+apt-get install -y \
+ libglfw3-dev \
+ libgles2-mesa-dev \
+ libglib2.0-0 \
+ libosmesa6-dev
+# install requirements
+pip install -r requirements.txt
+# install kaolin
+pip install kaolin==0.13.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-${YOUR_TORCH_VERSION}_${YOUR_CUDA_VERSION}.html
+```
+4. Build modules
+```sh
+cd core/lib/freqencoder
+python setup.py install
+cd ../gridencoder
+python setup.py install
+cd ../../..
+```
+5. Fetch third-party code:
+```sh
+git clone https://github.com/ZHKKKe/MODNet thirdparties/MODNet
+```
+6. Download necessary data for body models: `bash scripts/download_body_data.sh`
+7. Download the `runwayml/stable-diffusion-v1-5` checkpoint, background images and class regularization data for DreamBooth by running `bash scripts/download_dreambooth_data.sh`. You can also try another version of the SD model, or use other images of `man` and `woman` for regularization (we simply generated this data with the SD model).
+
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100755
index 0000000..8f9830e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,61 @@
+accelerate==0.17.1
+albumentations==1.1.0
+boto3
+chumpy
+dataclasses
+dearpygui==1.8.0
+diffusers==0.15.0
+einops==0.4.1
+imageio-ffmpeg==0.4.8
+imageio==2.26.0
+kornia==0.6
+lpips==0.1.4
+matplotlib==3.7.1
+mediapipe
+mesh-to-sdf==0.0.14
+ninja==1.11.1
+numpy==1.24.3
+omegaconf==2.1.1
+open3d
+opencv_contrib_python
+opencv-python==4.7.0.72
+packaging==23.0
+pandas==1.5.3
+pillow==9.0.1
+protobuf
+pudb==2019.2
+pyfqmr
+PyMCubes==0.1.4
+pymeshfix==0.16.2
+pymeshlab==2022.2.post3
+PyOpenGL==3.1.5
+pyrender==0.1.45
+pyrr==0.10.3
+pytorch-lightning==1.9.1
+pyvista==0.38.5
+replicate
+rich==13.3.2
+rtree
+scikit-image
+scikit-learn==1.2.2
+scipy==1.9.1
+setuptools==59.5.0
+streamlit>=0.73.1
+tensorboardX==2.6
+termcolor
+test-tube>=0.7.5
+tetgen==0.6.2
+torch-ema==0.3
+torch-fidelity==0.3.0
+torchmetrics==0.6.0
+tqdm==4.62.3
+transformers==4.27.1
+trimesh==3.20.2
+xatlas==0.0.7
+yacs==0.1.8
+git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+git+https://github.com/openai/CLIP.git@main#egg=clip
+git+https://github.com/facebookresearch/pytorch3d.git@v0.7.1
+# pip install kaolin==0.13.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.13.0_cu116.html
\ No newline at end of file
diff --git a/scripts/download_body_data.sh b/scripts/download_body_data.sh
new file mode 100644
index 0000000..f140a77
--- /dev/null
+++ b/scripts/download_body_data.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+# Download the SMPL / SMPL-X body models plus the ECON and PIXIE data into data/body_data.
+# Run from the repository root: every path below is relative to it. Directory changes
+# are confined to subshells so the working directory is always the root between steps.
+
+# Percent-encode $1 so it can be sent as a form field in the POST body.
+urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; }
+
+mkdir -p data/body_data/smpl_related/models
+
+# username and password input
+echo -e "\nYou need to register at https://icon.is.tue.mpg.de/, according to Installation Instruction."
+read -r -p "Username (ICON):" username
+# -s keeps the password off the terminal; the echo restores the newline that -s suppresses.
+read -r -s -p "Password (ICON):" password
+echo
+# Quoted so credentials containing spaces or glob characters reach urle as one argument.
+username=$(urle "$username")
+password=$(urle "$password")
+
+# SMPL (Male, Female)
+echo -e "\nDownloading SMPL..."
+wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smpl&sfile=SMPL_python_v.1.0.0.zip&resume=1' -O './data/body_data/smpl_related/models/SMPL_python_v.1.0.0.zip' --no-check-certificate --continue
+unzip data/body_data/smpl_related/models/SMPL_python_v.1.0.0.zip -d data/body_data/smpl_related/models
+mv data/body_data/smpl_related/models/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl data/body_data/smpl_related/models/smpl/SMPL_FEMALE.pkl
+mv data/body_data/smpl_related/models/smpl/models/basicmodel_m_lbs_10_207_0_v1.0.0.pkl data/body_data/smpl_related/models/smpl/SMPL_MALE.pkl
+(
+    # Bail out if the directory is missing so rm -rf never runs somewhere else.
+    cd data/body_data/smpl_related/models || exit 1
+    rm -rf *.zip __MACOSX smpl/models smpl/smpl_webuser
+)
+
+# SMPL (Neutral, from SMPLIFY)
+echo -e "\nDownloading SMPLify..."
+wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplify&sfile=mpips_smplify_public_v2.zip&resume=1' -O './data/body_data/smpl_related/models/mpips_smplify_public_v2.zip' --no-check-certificate --continue
+unzip data/body_data/smpl_related/models/mpips_smplify_public_v2.zip -d data/body_data/smpl_related/models
+mv data/body_data/smpl_related/models/smplify_public/code/models/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl data/body_data/smpl_related/models/smpl/SMPL_NEUTRAL.pkl
+(
+    cd data/body_data/smpl_related/models || exit 1
+    rm -rf *.zip smplify_public
+)
+
+# SMPL-X
+echo -e "\nDownloading SMPL-X..."
+wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=models_smplx_v1_1.zip&resume=1' -O './data/body_data/smpl_related/models/models_smplx_v1_1.zip' --no-check-certificate --continue
+unzip data/body_data/smpl_related/models/models_smplx_v1_1.zip -d data/body_data/smpl_related
+rm -f data/body_data/smpl_related/models/models_smplx_v1_1.zip
+
+# ECON
+echo -e "\nDownloading ECON..."
+wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=icon&sfile=econ_data.zip&resume=1' -O './data/body_data/econ_data.zip' --no-check-certificate --continue
+(
+    # The archive was saved into data/body_data, so it must be unpacked from there.
+    cd data/body_data || exit 1
+    unzip econ_data.zip
+    mv smpl_data smpl_related/
+    rm -f econ_data.zip
+)
+
+mkdir -p data/body_data/HPS
+
+# PIXIE
+echo -e "\nDownloading PIXIE..."
+wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=icon&sfile=HPS/pixie_data.zip&resume=1' -O './data/body_data/HPS/pixie_data.zip' --no-check-certificate --continue
+(
+    cd data/body_data/HPS || exit 1
+    unzip pixie_data.zip
+    rm -f pixie_data.zip
+)
+
+# PyMAF-X
+# echo -e "\nDownloading PyMAF-X..."
+# wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=icon&sfile=HPS/pymafx_data.zip&resume=1' -O './data/body_data/HPS/pymafx_data.zip' --no-check-certificate --continue
+# (
+#     cd data/body_data/HPS || exit 1
+#     unzip pymafx_data.zip
+#     rm -f pymafx_data.zip
+# )
\ No newline at end of file
diff --git a/scripts/download_dreambooth_data.sh b/scripts/download_dreambooth_data.sh
new file mode 100644
index 0000000..3a15155
--- /dev/null
+++ b/scripts/download_dreambooth_data.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+# Download the Stable Diffusion v1.5 checkpoint and the DreamBooth background /
+# regularization images into data/dreambooth_data. Run from the repository root.
+mkdir -p data/dreambooth_data
+
+# SD v1-5 LDM checkpoint
+echo -e "\nDownloading stable diffusion v1.5..."
+wget 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt' -O data/dreambooth_data/v1-5-pruned.ckpt
+
+# DreamBooth background and regularization images
+echo -e "\nDownloading dreambooth background images and regularization images..."
+# dl=1 asks Dropbox for the file itself; dl=0 serves the HTML preview page instead.
+wget 'https://www.dropbox.com/scl/fi/ucj961vt90hix12up2nyv/dreambooth_data.zip?rlkey=w1frc8hzkjskmnesokextp84r&dl=1' -O 'data/dreambooth_data/dreambooth_data.zip' --no-check-certificate --continue
+(
+    # After the cd the archive is addressed by its bare name, not its repo-root path.
+    cd data/dreambooth_data || exit 1
+    unzip dreambooth_data.zip
+    rm -f dreambooth_data.zip
+)
\ No newline at end of file
diff --git a/scripts/run.sh b/scripts/run.sh
new file mode 100755
index 0000000..0bc749c
--- /dev/null
+++ b/scripts/run.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+# End-to-end TeCH pipeline for a single input image.
+# Usage: scripts/run.sh <input_image> <exp_dir>
+# All tool and config paths below are relative to the current directory.
+set -x
+
+# Both arguments are required: with an empty EXP_DIR the cleanup in step 3
+# would expand to `rm -rf /ldm`.
+if [ "$#" -lt 2 ] || [ -z "$1" ] || [ -z "$2" ]; then
+    echo "Usage: $0 <input_image> <exp_dir>" >&2
+    exit 1
+fi
+
+export INPUT_FILE="$1"
+export EXP_DIR="$2"
+# Subject name = input file name up to its first '.'.
+export SUBJECT_NAME="$(basename "$1" | cut -d"." -f1)"
+export REPLICATE_API_TOKEN=""; # your replicate token for BLIP API
+export CUDA_HOME=/usr/local/cuda-11.6/ #/your/cuda/home/dir;
+export PYOPENGL_PLATFORM=osmesa
+export MESA_GL_VERSION_OVERRIDE=4.1
+export PYTHONPATH="$PYTHONPATH:$(pwd)"
+
+# Step 1: Preprocess image, get SMPL-X & normal estimation
+python utils/body_utils/preprocess.py --in_path "${INPUT_FILE}" --out_dir "${EXP_DIR}"
+
+# Step 2: Get BLIP prompt and gender, you can also use your own prompt
+python utils/get_prompt_blip.py --img-path "${EXP_DIR}/png/${SUBJECT_NAME}_crop.png" --out-path "${EXP_DIR}/prompt.txt"
+# python core/get_prompt.py ${EXP_DIR}/png/${SUBJECT_NAME}_crop.png
+# prompt.txt is '|'-separated: field 1 is the prompt, field 2 the gender.
+export PROMPT="$(cut -d'|' -f1 "${EXP_DIR}/prompt.txt")"
+export GENDER="$(cut -d'|' -f2 "${EXP_DIR}/prompt.txt")"
+
+# Step 3: Finetune Dreambooth model (minimal GPU memory requirement: 2x32G)
+rm -rf "${EXP_DIR}/ldm"
+python utils/ldm_utils/main.py -t --data_root "${EXP_DIR}/png/" --logdir "${EXP_DIR}/ldm/" --reg_data_root "data/dreambooth_data/class_${GENDER}_images/" --bg_root data/dreambooth_data/bg_images/ --class_word "${GENDER}" --no-test --gpus 0,1
+# Convert Dreambooth model to diffusers format
+python utils/ldm_utils/convert_ldm_to_diffusers.py --checkpoint_path "${EXP_DIR}/ldm/_v1-finetune_unfrozen/checkpoints/last.ckpt" --original_config_file utils/ldm_utils/configs/stable-diffusion/v1-inference.yaml --scheduler_type ddim --image_size 512 --prediction_type epsilon --dump_path "${EXP_DIR}/sd_model"
+# [Optional] you can delete the original ldm exp dir to save disk memory
+rm -rf "${EXP_DIR}/ldm"
+
+# Step 4: Run geometry stage (Run on a single GPU)
+python core/main.py --config configs/tech_geometry.yaml --exp_dir "$EXP_DIR" --sub_name "$SUBJECT_NAME"
+python utils/body_utils/postprocess.py --dir "$EXP_DIR/obj" --name "$SUBJECT_NAME"
+
+# Step 5: Run texture stage (Run on a single GPU)
+python core/main.py --config configs/tech_texture.yaml --exp_dir "$EXP_DIR" --sub_name "$SUBJECT_NAME"
+
+# [Optional] export textured mesh with UV map, using atlas for UV unwraping.
+python core/main.py --config configs/tech_texture_export.yaml --exp_dir "$EXP_DIR" --sub_name "$SUBJECT_NAME" --test
\ No newline at end of file
diff --git a/thirdparties/lpips/LICENSE b/thirdparties/lpips/LICENSE
new file mode 100755
index 0000000..842c363
--- /dev/null
+++ b/thirdparties/lpips/LICENSE
@@ -0,0 +1,23 @@
+Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/thirdparties/lpips/__init__.py b/thirdparties/lpips/__init__.py
new file mode 100755
index 0000000..2436565
--- /dev/null
+++ b/thirdparties/lpips/__init__.py
@@ -0,0 +1,148 @@
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import torch
+
+from .lpips import *
+
+
+def normalize_tensor(in_feat,eps=1e-10):
+    """Divide ``in_feat`` by its L2 norm taken over dim 1.
+
+    ``eps`` is added both under the square root and to the resulting norm,
+    so an all-zero slice does not divide by zero.
+    """
+    squared_sum = torch.sum(in_feat**2, dim=1, keepdim=True)
+    denom = torch.sqrt(squared_sum + eps) + eps
+    return in_feat / denom
+
+def l2(p0, p1, range=255.):
+    """Return half the mean squared difference of ``p0`` and ``p1`` after dividing each by ``range``."""
+    scaled_diff = p0 / range - p1 / range
+    return .5 * np.mean(scaled_diff**2)
+
+def psnr(p0, p1, peak=255.):
+    """Return ``10 * log10(peak**2 / MSE)`` for ``p0`` and ``p1``; inputs are promoted to float first."""
+    mse = np.mean((1.*p0 - 1.*p1)**2)
+    return 10 * np.log10(peak**2 / mse)
+
+def dssim(p0, p1, range=255.):
+    """Return the structural dissimilarity ``(1 - SSIM) / 2`` of two multichannel images.
+
+    Args:
+        p0, p1: images with the channel axis last.
+        range: data range handed to the SSIM implementation.
+    """
+    import inspect
+    try:
+        # Current scikit-image exposes SSIM here; compare_ssim was removed from skimage.measure.
+        from skimage.metrics import structural_similarity as ssim
+    except ImportError:
+        from skimage.measure import compare_ssim as ssim
+    # Newer releases select the channel axis with `channel_axis`, older ones with
+    # `multichannel`. The signature is inspected because older versions silently
+    # swallow unknown keywords through **kwargs.
+    if 'channel_axis' in inspect.signature(ssim).parameters:
+        channel_kwargs = {'channel_axis': -1}
+    else:
+        channel_kwargs = {'multichannel': True}
+    return (1 - ssim(p0, p1, data_range=range, **channel_kwargs)) / 2.
+
+def rgb2lab(in_img,mean_cent=False):
+    """Convert an RGB image to Lab, optionally subtracting 50 from the first channel.
+
+    NOTE(review): a later ``rgb2lab(input)`` definition in this module rebinds the
+    name, so this version is not the one importers of the module receive.
+    """
+    from skimage import color
+    lab = color.rgb2lab(in_img)
+    if not mean_cent:
+        return lab
+    lab[:,:,0] -= 50
+    return lab
+
+def tensor2np(tensor_obj):
+    """Return the first batch item of ``tensor_obj`` as a float numpy array with axes reordered (1, 2, 0)."""
+    first_item = tensor_obj[0].cpu().float().numpy()
+    return first_item.transpose((1,2,0))
+
+def np2tensor(np_obj):
+    """Turn a 3-D array of shape (d0, d1, d2) into a ``torch.Tensor`` of shape (1, d2, d0, d1)."""
+    with_batch_axis = np_obj[:, :, :, np.newaxis]
+    return torch.Tensor(with_batch_axis.transpose((3, 2, 0, 1)))
+
+def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
+    """Convert an image tensor to a Lab tensor.
+
+    The tensor goes through ``tensor2im`` and ``skimage.color.rgb2lab``.
+    With ``mc_only`` the first channel is shifted by -50 only; otherwise, with
+    ``to_norm``, it is shifted by -50 and the whole image is divided by 100.
+    """
+    from skimage import color
+
+    lab = color.rgb2lab(tensor2im(image_tensor))
+    if mc_only or to_norm:
+        lab[:,:,0] -= 50
+    if to_norm and not mc_only:
+        lab = lab/100.
+
+    return np2tensor(lab)
+
+def tensorlab2tensor(lab_tensor,return_inbnd=False):
+    """Convert a Lab tensor (scaled by 1/100, first channel shifted by -50) back to an image tensor.
+
+    With ``return_inbnd`` a mask tensor is returned as well: it is 1 where
+    converting the RGB result back to Lab matches the input within 2.0 on
+    every channel.
+
+    Note: ``warnings.filterwarnings("ignore")`` is called here and is not
+    undone on return.
+    """
+    from skimage import color
+    import warnings
+    warnings.filterwarnings("ignore")
+
+    # undo the /100 scaling and the -50 shift of the first channel
+    lab = tensor2np(lab_tensor)*100.
+    lab[:,:,0] = lab[:,:,0]+50
+    rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
+
+    if not return_inbnd:
+        return im2tensor(rgb_back)
+
+    # convert back to lab, see if we match
+    lab_back = color.rgb2lab(rgb_back.astype('uint8'))
+    close = 1.*np.isclose(lab_back,lab,atol=2.)
+    in_bounds = np2tensor(np.prod(close,axis=2)[:,:,np.newaxis])
+    return (im2tensor(rgb_back),in_bounds)
+
+def load_image(path):
+    """Load an image from ``path``, choosing the reader from the file extension.
+
+    - ``dng``: read with rawpy and post-processed.
+    - ``bmp`` / ``jpg`` / ``png`` / ``jpeg``: read with OpenCV, channel order reversed.
+    - anything else: read with matplotlib, first three channels scaled to uint8.
+    """
+    if(path[-3:] == 'dng'):
+        import rawpy
+        with rawpy.imread(path) as raw:
+            img = raw.postprocess()
+    elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png' or path[-4:]=='jpeg'):
+        import cv2
+        return cv2.imread(path)[:,:,::-1]
+    else:
+        # `plt` was never imported in this module, so this branch raised NameError.
+        import matplotlib.pyplot as plt
+        img = (255*plt.imread(path)[:,:,:3]).astype('uint8')
+
+    return img
+
+def rgb2lab(input):
+    """Convert an RGB image to Lab after dividing it by 255."""
+    from skimage import color
+    unit_range = input / 255.
+    return color.rgb2lab(unit_range)
+
+def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
+    """Convert the first batch item of ``image_tensor`` to an image array.
+
+    Axes are reordered (1, 2, 0), then values are mapped by
+    ``(x + cent) * factor`` and cast to ``imtype``.
+    """
+    first_item = image_tensor[0].cpu().float().numpy()
+    rescaled = (np.transpose(first_item, (1, 2, 0)) + cent) * factor
+    return rescaled.astype(imtype)
+
+def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
+    """Map an image array of shape (d0, d1, d2) to a tensor of shape (1, d2, d0, d1).
+
+    Values are transformed by ``image / factor - cent``; ``imtype`` is not used.
+    """
+    normalized = image / factor - cent
+    batched = normalized[:, :, :, np.newaxis].transpose((3, 2, 0, 1))
+    return torch.Tensor(batched)
+
+def tensor2vec(vector_tensor):
+    """Return element [0, 0] of the last two axes of a 4-D tensor as a 2-D numpy array."""
+    as_numpy = vector_tensor.data.cpu().numpy()
+    return as_numpy[:, :, 0, 0]
+
+
+def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
+# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
+    """Convert the first batch item of ``image_tensor`` to an image array via ``(x + cent) * factor``, cast to ``imtype``.
+
+    This definition repeats an earlier one of the same name in this module and
+    is the one left bound after import.
+    """
+    chw = image_tensor[0].cpu().float().numpy()
+    hwc = np.transpose(chw, (1, 2, 0))
+    return ((hwc + cent) * factor).astype(imtype)
+
+def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
+# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
+    """Map an image array (d0, d1, d2) to a tensor (1, d2, d0, d1) via ``image / factor - cent``.
+
+    This definition repeats an earlier one of the same name in this module and
+    is the one left bound after import. ``imtype`` is not used.
+    """
+    shifted = image / factor - cent
+    return torch.Tensor(shifted[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
+
+
+
+def voc_ap(rec, prec, use_07_metric=False):
+    """ ap = voc_ap(rec, prec, [use_07_metric])
+    Compute VOC average precision from recall and precision arrays.
+    With use_07_metric the VOC 2007 11-point method is used; otherwise
+    (the default) the area under the precision envelope is returned.
+    """
+    if not use_07_metric:
+        # pad both curves with sentinel values at each end
+        padded_rec = np.concatenate(([0.], rec, [1.]))
+        envelope = np.concatenate(([0.], prec, [0.]))
+
+        # sweep right to left so precision never increases with recall
+        for idx in reversed(range(1, envelope.size)):
+            envelope[idx - 1] = np.maximum(envelope[idx - 1], envelope[idx])
+
+        # indices where recall changes value
+        steps = np.where(padded_rec[1:] != padded_rec[:-1])[0]
+
+        # sum (\Delta recall) * prec over those steps
+        return np.sum((padded_rec[steps + 1] - padded_rec[steps]) * envelope[steps + 1])
+
+    # 11 point metric: average the best precision at recall >= 0.0, 0.1, ..., 1.0
+    ap = 0.
+    for threshold in np.arange(0., 1.1, 0.1):
+        reached = rec >= threshold
+        best = np.max(prec[reached]) if np.sum(reached) != 0 else 0
+        ap = ap + best / 11.
+    return ap
+
diff --git a/thirdparties/lpips/lpips.py b/thirdparties/lpips/lpips.py
new file mode 100755
index 0000000..8e1f552
--- /dev/null
+++ b/thirdparties/lpips/lpips.py
@@ -0,0 +1,218 @@
+
+from __future__ import absolute_import
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+from torch.autograd import Variable
+import numpy as np
+from . import pretrained_networks as pn
+import torch.nn
+import thirdparties.lpips as lpips
+
+def spatial_average(in_tens, keepdim=True):
+    """Average ``in_tens`` over dims 2 and 3; ``keepdim`` keeps them as size-1 axes."""
+    spatial_dims = [2, 3]
+    return in_tens.mean(spatial_dims, keepdim=keepdim)
+
+def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W
+    """Bilinearly resize ``in_tens`` to spatial size ``out_HW`` (``align_corners=False``)."""
+    # The input height/width used to be read into locals that were never used.
+    return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)
+
+# Learned perceptual metric
+class LPIPS(nn.Module):
+    """Learned perceptual metric: distance between two images in the feature space of a backbone network."""
+
+    def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
+        pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True):
+        """Set up the backbone and, optionally, the linear calibration layers.
+
+        Args:
+            pretrained: [True] means load linear weights (only read when ``lpips`` is True).
+            net: backbone type: 'alex', 'vgg' / 'vgg16' or 'squeeze'.
+            version: '0.1' applies the input scaling layer in ``forward``; other values skip it.
+            lpips: [True] means with linear calibration on top of base network;
+                False is the baseline that just sums the per-layer differences.
+            spatial: return a per-pixel map instead of a spatially averaged value.
+            pnet_rand: build the backbone without its pretrained weights.
+            pnet_tune: forwarded to the backbone as ``requires_grad``.
+            use_dropout: put a dropout layer in front of each linear layer.
+            model_path: linear-weights file; defaults to ``weights/v<version>/<net>.pth``
+                next to the file that defines this class.
+            eval_mode: switch the module to eval mode after construction.
+            verbose: print setup and weight-loading messages.
+        """
+        super(LPIPS, self).__init__()
+        if(verbose):
+            print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'%
+                ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))
+
+        self.pnet_type = net
+        self.pnet_tune = pnet_tune
+        self.pnet_rand = pnet_rand
+        self.spatial = spatial
+        self.lpips = lpips # false means baseline of just averaging all layers
+        self.version = version
+        self.scaling_layer = ScalingLayer()
+
+        if(self.pnet_type in ['vgg','vgg16']):
+            net_type = pn.vgg16
+            self.chns = [64,128,256,512,512]
+        elif(self.pnet_type=='alex'):
+            net_type = pn.alexnet
+            self.chns = [64,192,384,256,256]
+        elif(self.pnet_type=='squeeze'):
+            net_type = pn.squeezenet
+            self.chns = [64,128,256,384,384,512,512]
+        self.L = len(self.chns)
+
+        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
+
+        if(lpips):
+            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+            self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
+            if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
+                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
+                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
+                self.lins+=[self.lin5,self.lin6]
+            self.lins = nn.ModuleList(self.lins)
+
+            if(pretrained):
+                if(model_path is None):
+                    import inspect
+                    import os
+                    model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net)))
+
+                if(verbose):
+                    print('Loading model from: %s'%model_path)
+                self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
+
+        if(eval_mode):
+            self.eval()
+
+    def forward(self, in0, in1, retPerLayer=False, normalize=False):
+        """Compute the perceptual distance between ``in0`` and ``in1``.
+
+        Args:
+            in0, in1: image batches; expected in [-1, +1] unless ``normalize`` is set.
+            retPerLayer: also return the list of per-layer results.
+            normalize: inputs are in [0, 1] and get mapped to [-1, +1] first.
+
+        Returns:
+            ``val``, or ``(val, res)`` when ``retPerLayer`` is True, where ``val``
+            is the sum of the per-layer results in ``res``.
+        """
+        if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
+            in0 = 2 * in0 - 1
+            in1 = 2 * in1 - 1
+
+        # v0.0 - original release had a bug, where input was not scaled
+        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
+        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
+        feats0, feats1, diffs = {}, {}, {}
+
+        for kk in range(self.L):
+            feats0[kk], feats1[kk] = lpips.normalize_tensor(outs0[kk]), lpips.normalize_tensor(outs1[kk])
+            diffs[kk] = (feats0[kk]-feats1[kk])**2
+
+        if(self.lpips):
+            if(self.spatial):
+                res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
+            else:
+                res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
+        else:
+            if(self.spatial):
+                res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
+            else:
+                res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
+
+        # Accumulate out of place: an in-place `val += res[l]` writes into res[0],
+        # so the per-layer list returned with retPerLayer=True held the total in slot 0.
+        val = res[0]
+        for l in range(1,self.L):
+            val = val + res[l]
+
+        if(retPerLayer):
+            return (val, res)
+        else:
+            return val
+
+
+class ScalingLayer(nn.Module):
+    """Per-channel affine input normalization: ``(inp - shift) / scale`` over three channels."""
+
+    def __init__(self):
+        super(ScalingLayer, self).__init__()
+        # Stored as (1, 3, 1, 1) buffers so they broadcast over batch and spatial dims.
+        shift = torch.Tensor([-.030,-.088,-.188])
+        scale = torch.Tensor([.458,.448,.450])
+        self.register_buffer('shift', shift[None,:,None,None])
+        self.register_buffer('scale', scale[None,:,None,None])
+
+    def forward(self, inp):
+        centered = inp - self.shift
+        return centered / self.scale
+
+
+class NetLinLayer(nn.Module):
+    ''' A single bias-free 1x1 convolution, optionally preceded by dropout '''
+    def __init__(self, chn_in, chn_out=1, use_dropout=False):
+        super(NetLinLayer, self).__init__()
+
+        stages = []
+        if use_dropout:
+            stages.append(nn.Dropout())
+        stages.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
+        self.model = nn.Sequential(*stages)
+
+    def forward(self, x):
+        return self.model(x)
+
+class Dist2LogitLayer(nn.Module):
+    ''' Maps two distances to one value through three 1x1 convolutions; the result lies in [0,1] when use_sigmoid is True '''
+    def __init__(self, chn_mid=32, use_sigmoid=True):
+        super(Dist2LogitLayer, self).__init__()
+
+        stages = [
+            nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
+            nn.LeakyReLU(0.2,True),
+            nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
+            nn.LeakyReLU(0.2,True),
+            nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
+        ]
+        if(use_sigmoid):
+            stages.append(nn.Sigmoid())
+        self.model = nn.Sequential(*stages)
+
+    def forward(self,d0,d1,eps=0.1):
+        # Five input channels: both distances, their difference and both eps-stabilized ratios.
+        features = torch.cat((d0, d1, d0-d1, d0/(d1+eps), d1/(d0+eps)), dim=1)
+        return self.model.forward(features)
+
+class BCERankingLoss(nn.Module):
+    """Binary cross-entropy between a ``Dist2LogitLayer`` prediction and a judgement rescaled by ``(judge + 1) / 2``."""
+
+    def __init__(self, chn_mid=32):
+        super(BCERankingLoss, self).__init__()
+        self.net = Dist2LogitLayer(chn_mid=chn_mid)
+        # self.parameters = list(self.net.parameters())
+        self.loss = torch.nn.BCELoss()
+
+    def forward(self, d0, d1, judge):
+        target = (judge+1.)/2.
+        # The prediction is kept on the instance as `logit`.
+        self.logit = self.net.forward(d0,d1)
+        return self.loss(self.logit, target)
+
+# L2, DSSIM metrics
+class FakeNet(nn.Module):
+    """Base for the L2 and DSSIM metrics: stores the ``use_gpu`` flag and the colorspace name."""
+
+    def __init__(self, use_gpu=True, colorspace='Lab'):
+        super(FakeNet, self).__init__()
+        self.colorspace = colorspace
+        self.use_gpu = use_gpu
+
+class L2(FakeNet):
+    """L2 metric computed in RGB (mean squared error on the tensors) or in Lab space."""
+
+    def forward(self, in0, in1, retPerLayer=None):
+        assert(in0.size()[0]==1) # currently only supports batchSize 1
+
+        if(self.colorspace=='RGB'):
+            (N,C,X,Y) = in0.size()
+            # average over channels, then over each spatial axis in turn
+            per_pixel = torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y)
+            per_column = torch.mean(per_pixel,dim=2).view(N,1,1,Y)
+            return torch.mean(per_column,dim=3).view(N)
+        elif(self.colorspace=='Lab'):
+            lab0 = lpips.tensor2np(lpips.tensor2tensorlab(in0.data,to_norm=False))
+            lab1 = lpips.tensor2np(lpips.tensor2tensorlab(in1.data,to_norm=False))
+            value = lpips.l2(lab0, lab1, range=100.).astype('float')
+            ret_var = Variable( torch.Tensor((value,) ) )
+            if(self.use_gpu):
+                ret_var = ret_var.cuda()
+            return ret_var
+
+class DSSIM(FakeNet):
+    """Structural dissimilarity metric computed in RGB or Lab space via ``lpips.dssim``."""
+
+    def forward(self, in0, in1, retPerLayer=None):
+        assert(in0.size()[0]==1) # currently only supports batchSize 1
+
+        if(self.colorspace=='RGB'):
+            img0 = 1.*lpips.tensor2im(in0.data)
+            img1 = 1.*lpips.tensor2im(in1.data)
+            value = lpips.dssim(img0, img1, range=255.).astype('float')
+        elif(self.colorspace=='Lab'):
+            lab0 = lpips.tensor2np(lpips.tensor2tensorlab(in0.data,to_norm=False))
+            lab1 = lpips.tensor2np(lpips.tensor2tensorlab(in1.data,to_norm=False))
+            value = lpips.dssim(lab0, lab1, range=100.).astype('float')
+        ret_var = Variable( torch.Tensor((value,) ) )
+        if(self.use_gpu):
+            ret_var = ret_var.cuda()
+        return ret_var
+
+def print_network(net):
+    """Print ``net`` followed by its total number of parameter elements."""
+    total = sum(param.numel() for param in net.parameters())
+    print('Network',net)
+    print('Total number of parameters: %d' % total)
diff --git a/thirdparties/lpips/pretrained_networks.py b/thirdparties/lpips/pretrained_networks.py
new file mode 100755
index 0000000..a70ebbe
--- /dev/null
+++ b/thirdparties/lpips/pretrained_networks.py
@@ -0,0 +1,180 @@
+from collections import namedtuple
+import torch
+from torchvision import models as tv
+
+class squeezenet(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(squeezenet, self).__init__()
+ pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.slice6 = torch.nn.Sequential()
+ self.slice7 = torch.nn.Sequential()
+ self.N_slices = 7
+ for x in range(2):
+ self.slice1.add_module(str(x), pretrained_features[x])
+ for x in range(2,5):
+ self.slice2.add_module(str(x), pretrained_features[x])
+ for x in range(5, 8):
+ self.slice3.add_module(str(x), pretrained_features[x])
+ for x in range(8, 10):
+ self.slice4.add_module(str(x), pretrained_features[x])
+ for x in range(10, 11):
+ self.slice5.add_module(str(x), pretrained_features[x])
+ for x in range(11, 12):
+ self.slice6.add_module(str(x), pretrained_features[x])
+ for x in range(12, 13):
+ self.slice7.add_module(str(x), pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1 = h
+ h = self.slice2(h)
+ h_relu2 = h
+ h = self.slice3(h)
+ h_relu3 = h
+ h = self.slice4(h)
+ h_relu4 = h
+ h = self.slice5(h)
+ h_relu5 = h
+ h = self.slice6(h)
+ h_relu6 = h
+ h = self.slice7(h)
+ h_relu7 = h
+ vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
+ out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
+
+ return out
+
+
+class alexnet(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(alexnet, self).__init__()
+ alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(2):
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
+ for x in range(2, 5):
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
+ for x in range(5, 8):
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
+ for x in range(8, 10):
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
+ for x in range(10, 12):
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1 = h
+ h = self.slice2(h)
+ h_relu2 = h
+ h = self.slice3(h)
+ h_relu3 = h
+ h = self.slice4(h)
+ h_relu4 = h
+ h = self.slice5(h)
+ h_relu5 = h
+ alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
+
+ return out
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(vgg16, self).__init__()
+ vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+
+ return out
+
+
+
+class resnet(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
+ super(resnet, self).__init__()
+ if(num==18):
+ self.net = tv.resnet18(pretrained=pretrained)
+ elif(num==34):
+ self.net = tv.resnet34(pretrained=pretrained)
+ elif(num==50):
+ self.net = tv.resnet50(pretrained=pretrained)
+ elif(num==101):
+ self.net = tv.resnet101(pretrained=pretrained)
+ elif(num==152):
+ self.net = tv.resnet152(pretrained=pretrained)
+ self.N_slices = 5
+
+ self.conv1 = self.net.conv1
+ self.bn1 = self.net.bn1
+ self.relu = self.net.relu
+ self.maxpool = self.net.maxpool
+ self.layer1 = self.net.layer1
+ self.layer2 = self.net.layer2
+ self.layer3 = self.net.layer3
+ self.layer4 = self.net.layer4
+
+ def forward(self, X):
+ h = self.conv1(X)
+ h = self.bn1(h)
+ h = self.relu(h)
+ h_relu1 = h
+ h = self.maxpool(h)
+ h = self.layer1(h)
+ h_conv2 = h
+ h = self.layer2(h)
+ h_conv3 = h
+ h = self.layer3(h)
+ h_conv4 = h
+ h = self.layer4(h)
+ h_conv5 = h
+
+ outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
+
+ return out
diff --git a/utils/body_utils/configs/body.yaml b/utils/body_utils/configs/body.yaml
new file mode 100755
index 0000000..44bbe21
--- /dev/null
+++ b/utils/body_utils/configs/body.yaml
@@ -0,0 +1,190 @@
+name: body
+ckpt_dir: "./data/body_data/ckpt/"
+normal_path: "./data/body_data/ckpt/normal.ckpt"
+results_path: "./results"
+
+net:
+ in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
+ in_geo: (('normal_F',3), ('normal_B',3))
+
+test_mode: True
+batch_size: 1
+
+dataset:
+ prior_type: "SMPL"
+
+SOLVER:
+ MAX_ITER: 500000
+ TYPE: Adam
+ BASE_LR: 0.00005
+ GAMMA: 0.1
+ STEPS: [0]
+ EPOCHS: [0]
+# DEBUG: False
+LOGDIR: ''
+DEVICE: cuda
+# NUM_WORKERS: 8
+SEED_VALUE: -1
+LOSS:
+ KP_2D_W: 300.0
+ KP_3D_W: 300.0
+ HF_KP_2D_W: 1000.0
+ HF_KP_3D_W: 1000.0
+ GL_HF_KP_2D_W: 30.
+ FEET_KP_2D_W: 0.
+ SHAPE_W: 0.06
+ POSE_W: 60.0
+ VERT_W: 0.0
+ VERT_REG_W: 300.0
+ INDEX_WEIGHTS: 2.0
+ # Loss weights for surface parts. (24 Parts)
+ PART_WEIGHTS: 0.3
+ # Loss weights for UV regression.
+ POINT_REGRESSION_WEIGHTS: 0.5
+TRAIN:
+ NUM_WORKERS: 8
+ BATCH_SIZE: 64
+ LOG_FERQ: 100
+ SHUFFLE: True
+ PIN_MEMORY: True
+ BHF_MODE: 'full_body'
+TEST:
+ BATCH_SIZE: 32
+MODEL:
+ # IWP, Identity rotation and Weak Perspective Camera
+ USE_IWP_CAM: True
+ USE_GT_FL: False
+ PRED_PITCH: False
+ MESH_MODEL: 'smplx'
+ ALL_GENDER: False
+ EVAL_MODE: True
+ PyMAF:
+ BACKBONE: 'hr48'
+ HF_BACKBONE: 'res50'
+ MAF_ON: True
+ MLP_DIM: [256, 128, 64, 5]
+ HF_MLP_DIM: [256, 128, 64, 5]
+ MLP_VT_DIM: [256, 128, 64, 3]
+ N_ITER: 3
+ SUPV_LAST: False
+ AUX_SUPV_ON: True
+ HF_AUX_SUPV_ON: False
+ HF_BOX_CENTER: True
+ DP_HEATMAP_SIZE: 56
+ GRID_FEAT: False
+ USE_CAM_FEAT: True
+ HF_IMG_SIZE: 224
+ HF_DET: 'pifpaf'
+ OPT_WRIST: True
+ ADAPT_INTEGR: True
+ PRED_VIS_H: True
+ HAND_VIS_TH: 0.1
+ GRID_ALIGN:
+ USE_ATT: True
+ USE_FC: False
+ ATT_FEAT_IDX: 2
+ ATT_HEAD: 1
+ ATT_STARTS: 1
+RES_MODEL:
+ DECONV_WITH_BIAS: False
+ NUM_DECONV_LAYERS: 3
+ NUM_DECONV_FILTERS:
+ - 256
+ - 256
+ - 256
+ NUM_DECONV_KERNELS:
+ - 4
+ - 4
+ - 4
+POSE_RES_MODEL:
+ INIT_WEIGHTS: True
+ NAME: 'pose_resnet'
+ PRETR_SET: 'imagenet' # 'none' 'imagenet' 'coco'
+ # PRETRAINED: 'data/pretrained_model/resnet50-19c8e357.pth'
+ PRETRAINED_IM: 'data/pretrained_model/resnet50-19c8e357.pth'
+ PRETRAINED_COCO: 'data/pretrained_model/pose_resnet_50_256x192.pth.tar'
+ EXTRA:
+ TARGET_TYPE: 'gaussian'
+ HEATMAP_SIZE:
+ - 48
+ - 64
+ SIGMA: 2
+ FINAL_CONV_KERNEL: 1
+ DECONV_WITH_BIAS: False
+ NUM_DECONV_LAYERS: 3
+ NUM_DECONV_FILTERS:
+ - 256
+ - 256
+ - 256
+ NUM_DECONV_KERNELS:
+ - 4
+ - 4
+ - 4
+ NUM_LAYERS: 50
+HR_MODEL:
+ INIT_WEIGHTS: True
+ NAME: pose_hrnet
+ PRETR_SET: 'coco' # 'none' 'imagenet' 'coco'
+ PRETRAINED_IM: 'data/pretrained_model/hrnet_w48-imgnet-8ef0771d.pth'
+ PRETRAINED_COCO: 'data/pretrained_model/pose_hrnet_w48_256x192.pth'
+ TARGET_TYPE: gaussian
+ IMAGE_SIZE:
+ - 256
+ - 256
+ HEATMAP_SIZE:
+ - 64
+ - 64
+ SIGMA: 2
+ EXTRA:
+ PRETRAINED_LAYERS:
+ - 'conv1'
+ - 'bn1'
+ - 'conv2'
+ - 'bn2'
+ - 'layer1'
+ - 'transition1'
+ - 'stage2'
+ - 'transition2'
+ - 'stage3'
+ - 'transition3'
+ - 'stage4'
+ FINAL_CONV_KERNEL: 1
+ STAGE2:
+ NUM_MODULES: 1
+ NUM_BRANCHES: 2
+ BLOCK: BASIC
+ NUM_BLOCKS:
+ - 4
+ - 4
+ NUM_CHANNELS:
+ - 48
+ - 96
+ FUSE_METHOD: SUM
+ STAGE3:
+ NUM_MODULES: 4
+ NUM_BRANCHES: 3
+ BLOCK: BASIC
+ NUM_BLOCKS:
+ - 4
+ - 4
+ - 4
+ NUM_CHANNELS:
+ - 48
+ - 96
+ - 192
+ FUSE_METHOD: SUM
+ STAGE4:
+ NUM_MODULES: 3
+ NUM_BRANCHES: 4
+ BLOCK: BASIC
+ NUM_BLOCKS:
+ - 4
+ - 4
+ - 4
+ - 4
+ NUM_CHANNELS:
+ - 48
+ - 96
+ - 192
+ - 384
+ FUSE_METHOD: SUM
diff --git a/utils/body_utils/lib/IFGeo.py b/utils/body_utils/lib/IFGeo.py
new file mode 100755
index 0000000..8cb033d
--- /dev/null
+++ b/utils/body_utils/lib/IFGeo.py
@@ -0,0 +1,178 @@
+# -*- coding: utf-8 -*-
+
+# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+# holder of all proprietary rights on this computer program.
+# You can only use this computer program if you have closed
+# a license agreement with MPG or you get the right to use the computer
+# program from someone who is authorized to grant you that right.
+# Any use of the computer program without a valid license is prohibited and
+# liable to prosecution.
+#
+# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+# for Intelligent Systems. All rights reserved.
+#
+# Contact: ps-license@tuebingen.mpg.de
+
+from lib.common.seg3d_lossless import Seg3dLossless
+from lib.common.train_util import *
+import torch
+import numpy as np
+import pytorch_lightning as pl
+
+torch.backends.cudnn.benchmark = True
+
+
+class IFGeo(pl.LightningModule):
+ def __init__(self, cfg):
+ super(IFGeo, self).__init__()
+
+ self.cfg = cfg
+ self.batch_size = self.cfg.batch_size
+ self.lr_G = self.cfg.lr_G
+
+ self.use_sdf = cfg.sdf
+ self.mcube_res = cfg.mcube_res
+ self.clean_mesh_flag = cfg.clean_mesh
+ self.overfit = cfg.overfit
+
+ if cfg.dataset.prior_type == "SMPL":
+ from lib.net.IFGeoNet import IFGeoNet
+ self.netG = IFGeoNet(cfg)
+ else:
+ from lib.net.IFGeoNet_nobody import IFGeoNet
+ self.netG = IFGeoNet(cfg)
+
+ self.resolutions = (
+ np.logspace(
+ start=5,
+ stop=np.log2(self.mcube_res),
+ base=2,
+ num=int(np.log2(self.mcube_res) - 4),
+ endpoint=True,
+ ) + 1.0
+ )
+
+ self.resolutions = self.resolutions.astype(np.int16).tolist()
+
+ self.reconEngine = Seg3dLossless(
+ query_func=query_func_IF,
+ b_min=[[-1.0, 1.0, -1.0]],
+ b_max=[[1.0, -1.0, 1.0]],
+ resolutions=self.resolutions,
+ align_corners=True,
+ balance_value=0.50,
+ visualize=False,
+ debug=False,
+ use_cuda_impl=False,
+ faster=True,
+ )
+
+ self.export_dir = None
+ self.result_eval = {}
+
+ # Training related
+ def configure_optimizers(self):
+
+ # set optimizer
+ weight_decay = self.cfg.weight_decay
+ momentum = self.cfg.momentum
+
+ optim_params_G = [{"params": self.netG.parameters(), "lr": self.lr_G}]
+
+ if self.cfg.optim == "Adadelta":
+
+ optimizer_G = torch.optim.Adadelta(
+ optim_params_G, lr=self.lr_G, weight_decay=weight_decay
+ )
+
+ elif self.cfg.optim == "Adam":
+
+ optimizer_G = torch.optim.Adam(optim_params_G, lr=self.lr_G, weight_decay=weight_decay)
+
+ elif self.cfg.optim == "RMSprop":
+
+ optimizer_G = torch.optim.RMSprop(
+ optim_params_G,
+ lr=self.lr_G,
+ weight_decay=weight_decay,
+ momentum=momentum,
+ )
+
+ else:
+ raise NotImplementedError
+
+ # set scheduler
+ scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+ )
+
+ return [optimizer_G], [scheduler_G]
+
+ def training_step(self, batch, batch_idx):
+
+ self.netG.train()
+
+ preds_G = self.netG(batch)
+ error_G = self.netG.compute_loss(preds_G, batch["labels_geo"])
+
+ # metrics processing
+ metrics_log = {
+ "loss": error_G,
+ }
+
+ self.log_dict(
+ metrics_log, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True
+ )
+
+ return metrics_log
+
+ def training_epoch_end(self, outputs):
+
+ # metrics processing
+ metrics_log = {
+ "train/avgloss": batch_mean(outputs, "loss"),
+ }
+
+ self.log_dict(
+ metrics_log,
+ prog_bar=False,
+ logger=True,
+ on_step=False,
+ on_epoch=True,
+ rank_zero_only=True
+ )
+
+ def validation_step(self, batch, batch_idx):
+
+ self.netG.eval()
+ self.netG.training = False
+
+ preds_G = self.netG(batch)
+ error_G = self.netG.compute_loss(preds_G, batch["labels_geo"])
+
+ metrics_log = {
+ "val/loss": error_G,
+ }
+
+ self.log_dict(
+ metrics_log, prog_bar=True, logger=False, on_step=True, on_epoch=False, sync_dist=True
+ )
+
+ return metrics_log
+
+ def validation_epoch_end(self, outputs):
+
+ # metrics processing
+ metrics_log = {
+ "val/avgloss": batch_mean(outputs, "val/loss"),
+ }
+
+ self.log_dict(
+ metrics_log,
+ prog_bar=False,
+ logger=True,
+ on_step=False,
+ on_epoch=True,
+ rank_zero_only=True
+ )
diff --git a/utils/body_utils/lib/Normal.py b/utils/body_utils/lib/Normal.py
new file mode 100755
index 0000000..235c0ae
--- /dev/null
+++ b/utils/body_utils/lib/Normal.py
@@ -0,0 +1,217 @@
+from lib.net import NormalNet
+from lib.common.train_util import batch_mean
+import torch
+import numpy as np
+from skimage.transform import resize
+import pytorch_lightning as pl
+
+
+# LightningModule that trains the sub-networks held by NormalNet (self.netG:
+# netF, netB and, when a 'gan' loss is configured, netD) using manual
+# optimization -- each sub-network has its own Adam optimizer and scheduler.
+class Normal(pl.LightningModule):
+ # Reads learning rates and loss names from cfg and builds the NormalNet.
+ def __init__(self, cfg):
+ super(Normal, self).__init__()
+ self.cfg = cfg
+ self.batch_size = self.cfg.batch_size
+ self.lr_F = self.cfg.lr_netF
+ self.lr_B = self.cfg.lr_netB
+ self.lr_D = self.cfg.lr_netD
+ self.overfit = cfg.overfit
+
+ # Keep only the first element of every configured loss entry
+ # (presumably (name, weight) pairs -- TODO confirm against the config).
+ self.F_losses = [item[0] for item in self.cfg.net.front_losses]
+ self.B_losses = [item[0] for item in self.cfg.net.back_losses]
+ self.ALL_losses = self.F_losses + self.B_losses
+
+ # Optimizers are stepped by hand in training_step().
+ self.automatic_optimization = False
+
+ self.schedulers = []
+
+ self.netG = NormalNet(self.cfg)
+
+ # First element of each cfg.net.in_nml entry; used as batch keys below.
+ self.in_nml = [item[0] for item in cfg.net.in_nml]
+
+ # Training related
+ # Returns (optimizers, schedulers): one Adam (betas 0.5/0.999) plus one
+ # MultiStepLR per sub-network, in the order F, B and -- only when 'gan' is
+ # among the configured losses -- D.
+ def configure_optimizers(self):
+
+ optim_params_N_D = None
+ optimizer_N_D = None
+ scheduler_N_D = None
+
+ # set optimizer
+ optim_params_N_F = [{"params": self.netG.netF.parameters(), "lr": self.lr_F}]
+ optim_params_N_B = [{"params": self.netG.netB.parameters(), "lr": self.lr_B}]
+
+ optimizer_N_F = torch.optim.Adam(optim_params_N_F, lr=self.lr_F, betas=(0.5, 0.999))
+ optimizer_N_B = torch.optim.Adam(optim_params_N_B, lr=self.lr_B, betas=(0.5, 0.999))
+
+ scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+ )
+
+ scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+ )
+ if 'gan' in self.ALL_losses:
+ optim_params_N_D = [{"params": self.netG.netD.parameters(), "lr": self.lr_D}]
+ optimizer_N_D = torch.optim.Adam(optim_params_N_D, lr=self.lr_D, betas=(0.5, 0.999))
+ scheduler_N_D = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer_N_D, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+ )
+ self.schedulers = [scheduler_N_F, scheduler_N_B, scheduler_N_D]
+ optims = [optimizer_N_F, optimizer_N_B, optimizer_N_D]
+
+ else:
+ self.schedulers = [scheduler_N_F, scheduler_N_B]
+ optims = [optimizer_N_F, optimizer_N_B]
+
+ return optims, self.schedulers
+
+ # Logs one image strip built from every tensor in render_tensor: the first
+ # batch element of each is mapped with (x + 1) / 2, transposed with
+ # (1, 2, 0), resized to a square whose side is the "image" tensor's dim 2,
+ # and the results are concatenated along axis 1.
+ # NOTE(review): assumes tensors are (B, C, H, W) with values in [-1, 1] -- confirm.
+ def render_func(self, render_tensor, dataset, idx):
+
+ height = render_tensor["image"].shape[2]
+ result_list = []
+
+ for name in render_tensor.keys():
+ result_list.append(
+ resize(
+ ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(1, 2, 0),
+ (height, height),
+ anti_aliasing=True,
+ )
+ )
+
+ # When overfitting, every image is logged under the same key suffix (1).
+ self.logger.log_image(
+ key=f"Normal/{dataset}/{idx if not self.overfit else 1}",
+ images=[(np.concatenate(result_list, axis=1) * 255.0).astype(np.uint8)]
+ )
+
+ # One manual-optimization step: forward, per-network backward, optimizer
+ # steps, optional image logging, then metric logging. Returns the metrics dict.
+ def training_step(self, batch, batch_idx):
+
+ self.netG.train()
+
+ # retrieve the data
+ in_tensor = {}
+ for name in self.in_nml:
+ in_tensor[name] = batch[name]
+
+ FB_tensor = {"normal_F": batch["normal_F"], "normal_B": batch["normal_B"]}
+
+ in_tensor.update(FB_tensor)
+
+ preds_F, preds_B = self.netG(in_tensor)
+ error_dict = self.netG.get_norm_error(preds_F, preds_B, FB_tensor)
+
+ # All backward passes run before any optimizer step. In the GAN branch
+ # the netB backward keeps its graph (retain_graph=True) before the netD
+ # backward -- presumably because the two losses share graph nodes; verify
+ # against NormalNet.get_norm_error.
+ if 'gan' in self.ALL_losses:
+ (opt_F, opt_B, opt_D) = self.optimizers()
+ opt_F.zero_grad()
+ self.manual_backward(error_dict["netF"])
+ opt_B.zero_grad()
+ self.manual_backward(error_dict["netB"], retain_graph=True)
+ opt_D.zero_grad()
+ self.manual_backward(error_dict["netD"])
+ opt_F.step()
+ opt_B.step()
+ opt_D.step()
+ else:
+ (opt_F, opt_B) = self.optimizers()
+ opt_F.zero_grad()
+ self.manual_backward(error_dict["netF"])
+ opt_B.zero_grad()
+ self.manual_backward(error_dict["netB"])
+ opt_F.step()
+ opt_B.step()
+
+ # Every cfg.freq_show_train batches (never on batch 0, and only when
+ # cfg.devices == 1) re-run the net in eval mode and log a preview image.
+ if batch_idx > 0 and batch_idx % int(
+ self.cfg.freq_show_train
+ ) == 0 and self.cfg.devices == 1:
+
+ self.netG.eval()
+ with torch.no_grad():
+ nmlF, nmlB = self.netG(in_tensor)
+ in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
+ self.render_func(in_tensor, "train", self.global_step)
+
+ # metrics processing
+ metrics_log = {"loss": error_dict["netF"] + error_dict["netB"]}
+
+ if "gan" in self.ALL_losses:
+ metrics_log["loss"] += error_dict["netD"]
+
+ # Individual losses are logged as plain Python numbers via .item().
+ for key in error_dict.keys():
+ metrics_log["train/loss_" + key] = error_dict[key].item()
+
+ self.log_dict(
+ metrics_log, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True
+ )
+
+ return metrics_log
+
+ # Logs "<stage>/avg-<name>" for every key of the per-step outputs; keys
+ # without a "/" (i.e. "loss") are filed under the "train" stage.
+ def training_epoch_end(self, outputs):
+
+ # metrics processing
+ metrics_log = {}
+ for key in outputs[0].keys():
+ if "/" in key:
+ [stage, loss_name] = key.split("/")
+ else:
+ stage = "train"
+ loss_name = key
+ metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key)
+
+ self.log_dict(
+ metrics_log,
+ prog_bar=False,
+ logger=True,
+ on_step=False,
+ on_epoch=True,
+ rank_zero_only=True
+ )
+
+ # Forward pass in eval mode; returns "val/..." metrics without calling
+ # log_dict (aggregation happens in validation_epoch_end).
+ def validation_step(self, batch, batch_idx):
+
+ self.netG.eval()
+ self.netG.training = False
+
+ # retrieve the data
+ in_tensor = {}
+ for name in self.in_nml:
+ in_tensor[name] = batch[name]
+
+ FB_tensor = {"normal_F": batch["normal_F"], "normal_B": batch["normal_B"]}
+ in_tensor.update(FB_tensor)
+
+ preds_F, preds_B = self.netG(in_tensor)
+ error_dict = self.netG.get_norm_error(preds_F, preds_B, FB_tensor)
+
+ # Unlike training_step, batch 0 also triggers a preview image here.
+ if batch_idx % int(self.cfg.freq_show_train) == 0 and self.cfg.devices == 1:
+
+ with torch.no_grad():
+ nmlF, nmlB = self.netG(in_tensor)
+ in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
+ self.render_func(in_tensor, "val", batch_idx)
+
+ # metrics processing
+ metrics_log = {"val/loss": error_dict["netF"] + error_dict["netB"]}
+
+ if "gan" in self.ALL_losses:
+ metrics_log["val/loss"] += error_dict["netD"]
+
+ for key in error_dict.keys():
+ metrics_log["val/" + key] = error_dict[key].item()
+
+ return metrics_log
+
+ # Logs "<stage>/avg-<name>" for every per-step key; every key is expected
+ # to contain exactly one "/" (validation_step only emits "val/..." keys).
+ def validation_epoch_end(self, outputs):
+
+ # metrics processing
+ metrics_log = {}
+ for key in outputs[0].keys():
+ [stage, loss_name] = key.split("/")
+ metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key)
+
+ self.log_dict(
+ metrics_log,
+ prog_bar=False,
+ logger=True,
+ on_step=False,
+ on_epoch=True,
+ rank_zero_only=True
+ )
diff --git a/utils/body_utils/lib/__init__.py b/utils/body_utils/lib/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/utils/body_utils/lib/common/__init__.py b/utils/body_utils/lib/common/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/utils/body_utils/lib/common/blender_utils.py b/utils/body_utils/lib/common/blender_utils.py
new file mode 100755
index 0000000..a02260c
--- /dev/null
+++ b/utils/body_utils/lib/common/blender_utils.py
@@ -0,0 +1,383 @@
+import bpy
+import sys, os
+from math import radians
+import mathutils
+import bmesh
+
+print(sys.exec_prefix)
+from tqdm import tqdm
+import numpy as np
+
+##################################################
+# Globals
+##################################################
+
+# Number of yaw steps per object; process_file() renders one image every 360 / views degrees.
+views = 120
+
+# Render engine selector: 'cycles' picks CYCLES in setup_scene(), anything else BLENDER_EEVEE.
+render = 'eevee'
+# When rendering with cycles, enable CUDA devices (setup_scene() exits if none is found).
+cycles_gpu = False
+
+# quality_preview switches the cycles sample count between the two values
+# below and makes process_file() rewrite '.png' to '-preview.png' in output names.
+quality_preview = False
+samples_preview = 16
+samples_final = 256
+
+# Output image size in pixels.
+resolution_x = 512
+resolution_y = 512
+
+# Whether the sun light casts shadows and a shadow-catcher plane is added.
+shadows = False
+
+# diffuse_color = (57.0/255.0, 108.0/255.0, 189.0/255.0, 1.0)
+# diffuse_color = (18/255., 139/255., 142/255.,1) #correct
+# diffuse_color = (251/255., 60/255., 60/255.,1) #wrong
+
+# Apply smooth shading to the imported object in render_file().
+smooth = False
+
+# Freestyle wireframe overlay and its line thickness; quads converts
+# triangles to quads before rendering.
+wireframe = False
+line_thickness = 0.1
+quads = False
+
+# Passed to setup_diffuse_transparent_material() as the object / backface
+# transparency switches.
+object_transparent = False
+mouth_transparent = False
+
+# Compositor settings used by setup_compositing(): whether a background image
+# is blended in, its scale factor, and the factor of the alpha-over node.
+compositor_background_image = False
+compositor_image_scale = 1.0
+compositor_alpha = 0.7
+
+##################################################
+# Helper functions
+##################################################
+
+
+def blender_print(*args, **kwargs):
+ print(*args, **kwargs, file=sys.stderr)
+
+
+def using_app():
+ ''' Returns if script is running through Blender application (GUI or background processing)'''
+ return (not sys.argv[0].endswith('.py'))
+
+
+def setup_diffuse_transparent_material(target, color, object_transparent, backface_transparent):
+ ''' Sets up diffuse/transparent material with backface culling in cycles'''
+
+ mat = target.active_material
+ if mat is None:
+ # Create material
+ mat = bpy.data.materials.new(name='Material')
+ target.data.materials.append(mat)
+
+ mat.use_nodes = True
+ nodes = mat.node_tree.nodes
+ for node in nodes:
+ nodes.remove(node)
+
+ node_geometry = nodes.new('ShaderNodeNewGeometry')
+
+ node_diffuse = nodes.new('ShaderNodeBsdfDiffuse')
+ node_diffuse.inputs[0].default_value = color
+
+ node_transparent = nodes.new('ShaderNodeBsdfTransparent')
+ node_transparent.inputs[0].default_value = (1.0, 1.0, 1.0, 1.0)
+
+ node_emission = nodes.new('ShaderNodeEmission')
+ node_emission.inputs[0].default_value = (0.0, 0.0, 0.0, 1.0)
+
+ node_mix = nodes.new(type='ShaderNodeMixShader')
+ if object_transparent:
+ node_mix.inputs[0].default_value = 1.0
+ else:
+ node_mix.inputs[0].default_value = 0.0
+
+ node_mix_mouth = nodes.new(type='ShaderNodeMixShader')
+ if object_transparent or backface_transparent:
+ node_mix_mouth.inputs[0].default_value = 1.0
+ else:
+ node_mix_mouth.inputs[0].default_value = 0.0
+
+ node_mix_backface = nodes.new(type='ShaderNodeMixShader')
+
+ node_output = nodes.new(type='ShaderNodeOutputMaterial')
+
+ links = mat.node_tree.links
+
+ links.new(node_geometry.outputs[6], node_mix_backface.inputs[0])
+
+ links.new(node_diffuse.outputs[0], node_mix.inputs[1])
+ links.new(node_transparent.outputs[0], node_mix.inputs[2])
+ links.new(node_mix.outputs[0], node_mix_backface.inputs[1])
+
+ links.new(node_emission.outputs[0], node_mix_mouth.inputs[1])
+ links.new(node_transparent.outputs[0], node_mix_mouth.inputs[2])
+ links.new(node_mix_mouth.outputs[0], node_mix_backface.inputs[2])
+
+ links.new(node_mix_backface.outputs[0], node_output.inputs[0])
+ return
+
+
+##################################################
+
+
+# Configures the default Blender scene from the module-level settings: removes
+# the default cube, selects the render engine and resolution, optionally
+# enables CUDA for cycles, positions the camera and lights, and sets up the
+# optional shadow catcher, freestyle wireframe and compositor background.
+# Calls sys.exit(1) when GPU rendering is requested but no CUDA device exists.
+def setup_scene():
+ global render
+ global cycles_gpu
+ global quality_preview
+ global resolution_x
+ global resolution_y
+ global shadows
+ global wireframe
+ global line_thickness
+ global compositor_background_image
+
+ # Remove default cube
+ if 'Cube' in bpy.data.objects:
+ bpy.data.objects['Cube'].select_set(True)
+ bpy.ops.object.delete()
+
+ scene = bpy.data.scenes['Scene']
+
+ # Setup render engine
+ if render == 'cycles':
+ scene.render.engine = 'CYCLES'
+ else:
+ scene.render.engine = 'BLENDER_EEVEE'
+
+ scene.render.resolution_x = resolution_x
+ scene.render.resolution_y = resolution_y
+ scene.render.resolution_percentage = 100
+ scene.render.film_transparent = True
+ # The cycles sample count is written regardless of the selected engine.
+ if quality_preview:
+ scene.cycles.samples = samples_preview
+ else:
+ scene.cycles.samples = samples_final
+
+ # Setup Cycles CUDA GPU acceleration if requested
+ if render == 'cycles':
+ if cycles_gpu:
+ print('Activating GPU acceleration')
+ bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
+
+ # The device query API differs between Blender major versions (< 3 vs >= 3).
+ if bpy.app.version[0] >= 3:
+ cuda_devices = bpy.context.preferences.addons[
+ 'cycles'].preferences.get_devices_for_type(compute_device_type='CUDA')
+ else:
+ (cuda_devices, opencl_devices
+ ) = bpy.context.preferences.addons['cycles'].preferences.get_devices()
+
+ if (len(cuda_devices) < 1):
+ print('ERROR: CUDA GPU acceleration not available')
+ sys.exit(1)
+
+ # Enable every device of type CUDA and disable all others in the list.
+ for cuda_device in cuda_devices:
+ if cuda_device.type == 'CUDA':
+ cuda_device.use = True
+ print('Using CUDA device: ' + str(cuda_device.name))
+ else:
+ cuda_device.use = False
+ print('Igoring CUDA device: ' + str(cuda_device.name))
+
+ scene.cycles.device = 'GPU'
+ # Tile sizes are only set on Blender versions before 3.
+ if bpy.app.version[0] < 3:
+ scene.render.tile_x = 256
+ scene.render.tile_y = 256
+ else:
+ scene.cycles.device = 'CPU'
+ if bpy.app.version[0] < 3:
+ scene.render.tile_x = 64
+ scene.render.tile_y = 64
+
+ # Disable Blender 3 denoiser to properly measure Cycles render speed
+ if bpy.app.version[0] >= 3:
+ scene.cycles.use_denoising = False
+
+ # Setup camera
+ camera = bpy.data.objects['Camera']
+ camera.location = (0.0, -3, 1.8)
+ camera.rotation_euler = (radians(74), 0.0, 0)
+ bpy.data.cameras['Camera'].lens = 55
+
+ # Setup light
+
+ # Setup lights
+ # The default 'Light' object becomes a point light that casts no shadows.
+ light = bpy.data.objects['Light']
+ light.location = (-2, -3.0, 0.0)
+ light.rotation_euler = (radians(90.0), 0.0, 0.0)
+ bpy.data.lights['Light'].type = 'POINT'
+ bpy.data.lights['Light'].energy = 2
+ light.data.cycles.cast_shadow = False
+
+ # Add a sun light once; an existing 'Sun' object is reused unchanged.
+ if 'Sun' not in bpy.data.objects:
+ bpy.ops.object.light_add(type='SUN')
+ light_sun = bpy.context.active_object
+ light_sun.location = (0.0, -3, 0.0)
+ light_sun.rotation_euler = (radians(45.0), 0.0, radians(30))
+ bpy.data.lights['Sun'].energy = 2
+ light_sun.data.cycles.cast_shadow = shadows
+ else:
+ light_sun = bpy.data.objects['Sun']
+
+ if shadows:
+ # Setup shadow catcher
+ bpy.ops.mesh.primitive_plane_add()
+ plane = bpy.context.active_object
+ plane.scale = (5.0, 5.0, 1)
+
+ plane.cycles.is_shadow_catcher = True
+
+ # Exclude plane from diffuse cycles contribution to avoid bright pixel noise in body rendering
+ # plane.cycles_visibility.diffuse = False
+
+ # NOTE(review): the indentation of this file was lost, so it is unclear
+ # whether this block is nested under `if shadows:` -- confirm upstream.
+ if wireframe:
+ # Unmark freestyle edges
+ bpy.ops.object.mode_set(mode='EDIT')
+ bpy.ops.mesh.mark_freestyle_edge(clear=True)
+ bpy.ops.object.mode_set(mode='OBJECT')
+
+ # Setup freestyle mode for wireframe overlay rendering
+ if wireframe:
+ scene.render.use_freestyle = True
+ scene.render.line_thickness = line_thickness
+ bpy.context.view_layer.freestyle_settings.linesets[0].select_edge_mark = True
+
+ # Disable border edges so that we don't see contour of shadow catcher plane
+ bpy.context.view_layer.freestyle_settings.linesets[0].select_border = False
+ else:
+ scene.render.use_freestyle = False
+
+ if compositor_background_image:
+ # Setup compositing when using background image
+ setup_compositing()
+ else:
+ # Output transparent image when no background is used
+ scene.render.image_settings.color_mode = 'RGBA'
+
+
+##################################################
+
+
+def setup_compositing():
+
+ global compositor_image_scale
+ global compositor_alpha
+
+ # Node editor compositing setup
+ bpy.context.scene.use_nodes = True
+ tree = bpy.context.scene.node_tree
+
+ # Create input image node
+ image_node = tree.nodes.new(type='CompositorNodeImage')
+
+ scale_node = tree.nodes.new(type='CompositorNodeScale')
+ scale_node.inputs[1].default_value = compositor_image_scale
+ scale_node.inputs[2].default_value = compositor_image_scale
+
+ blend_node = tree.nodes.new(type='CompositorNodeAlphaOver')
+ blend_node.inputs[0].default_value = compositor_alpha
+
+ # Link nodes
+ links = tree.links
+ links.new(image_node.outputs[0], scale_node.inputs[0])
+
+ links.new(scale_node.outputs[0], blend_node.inputs[1])
+ links.new(tree.nodes['Render Layers'].outputs[0], blend_node.inputs[2])
+
+ links.new(blend_node.outputs[0], tree.nodes['Composite'].inputs[0])
+
+
+def render_file(input_file, input_dir, output_file, output_dir, yaw, correct):
+ '''Render image of given model file'''
+ global smooth
+ global object_transparent
+ global mouth_transparent
+ global compositor_background_image
+ global quads
+
+ path = input_dir + input_file
+
+ # Import object into scene
+ bpy.ops.import_scene.obj(filepath=path)
+ object = bpy.context.selected_objects[0]
+
+ object.rotation_euler = (radians(90.0), 0.0, radians(yaw))
+ z_bottom = np.min(np.array([vert.co for vert in object.data.vertices])[:, 1])
+ # z_top = np.max(np.array([vert.co for vert in object.data.vertices])[:,1])
+ # blender_print(radians(90.0), z_bottom, z_top)
+ object.location -= mathutils.Vector((0.0, 0.0, z_bottom))
+
+ if quads:
+ bpy.context.view_layer.objects.active = object
+ bpy.ops.object.mode_set(mode='EDIT')
+ bpy.ops.mesh.tris_convert_to_quads()
+ bpy.ops.object.mode_set(mode='OBJECT')
+
+ if smooth:
+ bpy.ops.object.shade_smooth()
+
+ # Mark freestyle edges
+ bpy.context.view_layer.objects.active = object
+ bpy.ops.object.mode_set(mode='EDIT')
+ bpy.ops.mesh.mark_freestyle_edge(clear=False)
+ bpy.ops.object.mode_set(mode='OBJECT')
+
+ if correct:
+ diffuse_color = (18 / 255., 139 / 255., 142 / 255., 1) #correct
+ else:
+ diffuse_color = (251 / 255., 60 / 255., 60 / 255., 1) #wrong
+
+ setup_diffuse_transparent_material(object, diffuse_color, object_transparent, mouth_transparent)
+
+ if compositor_background_image:
+ # Set background image
+ image_path = input_dir + input_file.replace('.obj', '_original.png')
+ bpy.context.scene.node_tree.nodes['Image'].image = bpy.data.images.load(image_path)
+
+ # Render
+ bpy.context.scene.render.filepath = os.path.join(output_dir, output_file)
+
+ # Silence console output of bpy.ops.render.render by redirecting stdout to file
+ # Note: Does not actually write the output to file (Windows 7)
+ sys.stdout.flush()
+ old = os.dup(1)
+ os.close(1)
+ os.open('blender_render.log', os.O_WRONLY | os.O_CREAT)
+
+ # Render
+ bpy.ops.render.render(write_still=True)
+
+ # Remove temporary output redirection
+ # sys.stdout.flush()
+ # os.close(1)
+ # os.dup(old)
+ # os.close(old)
+
+ # Delete last selected object from scene
+ object.select_set(True)
+ bpy.ops.object.delete()
+
+
+def process_file(input_file, input_dir, output_file, output_dir, correct=True):
+ global views
+ global quality_preview
+
+ if not input_file.endswith('.obj'):
+ print('ERROR: Invalid input: ' + input_file)
+ return
+
+ print('Processing: ' + input_file)
+ if output_file == '':
+ output_file = input_file[:-4]
+
+ if quality_preview:
+ output_file = output_file.replace('.png', '-preview.png')
+
+ angle = 360.0 / views
+ pbar = tqdm(range(0, views))
+ for view in pbar:
+ pbar.set_description(f"{os.path.basename(output_file)} | View:{str(view)}")
+ yaw = view * angle
+ output_file_view = f"{output_file}/{view:03d}.png"
+ if not os.path.exists(os.path.join(output_dir, output_file_view)):
+ render_file(input_file, input_dir, output_file_view, output_dir, yaw, correct)
+
+ cmd = "ffmpeg -loglevel quiet -r 30 -f lavfi -i color=c=white:s=512x512 -i " + os.path.join(output_dir, output_file, '%3d.png') + \
+ " -shortest -filter_complex \"[0:v][1:v]overlay=shortest=1,format=yuv420p[out]\" -map \"[out]\" -y " + output_dir+"/"+output_file+".mp4"
+ os.system(cmd)
diff --git a/utils/body_utils/lib/common/cloth_extraction.py b/utils/body_utils/lib/common/cloth_extraction.py
new file mode 100755
index 0000000..612a967
--- /dev/null
+++ b/utils/body_utils/lib/common/cloth_extraction.py
@@ -0,0 +1,196 @@
+import numpy as np
+import json
+import os
+import itertools
+import trimesh
+from matplotlib.path import Path
+from collections import Counter
+from sklearn.neighbors import KNeighborsClassifier
+
+
def load_segmentation(path, shape):
    """
    Read the garment polygons stored in a segmentation json file
    Arguments:
        path: path to the segmentation json file
        shape: accepted for interface compatibility; not used here
    Returns:
        A list with one dict per top-level "item*" entry, holding
        "type" (category name), "type_id" (category id) and
        "coordinates" (a list of (N, 2) arrays, one per polygon)
    """
    with open(path) as json_file:
        annotation = json.load(json_file)

    segmentations = []
    for name, entry in annotation.items():
        # Only keys starting with "item" describe a segmented object.
        if not name.startswith("item"):
            continue

        # Polygons are stored flat as [x1, y1, x2, y2, ...]; pair them up.
        polygons = [
            np.column_stack((flat[::2], flat[1::2])) for flat in entry["segmentation"]
        ]

        segmentations.append(
            {
                "type": entry["category_name"],
                "type_id": entry["category_id"],
                "coordinates": polygons,
            }
        )

    return segmentations
+
+
def smpl_to_recon_labels(recon, smpl, k=1):
    """
    Get the bodypart labels for the recon object by using the labels from the corresponding smpl object
    Arguments:
        recon: trimesh object (fully clothed model)
        smpl: trimesh object (smpl model)
        k: number of nearest neighbours to use
    Returns:
        Returns a dictionary containing the bodypart and the corresponding indices
    """
    seg_path = os.path.join(os.path.dirname(__file__), "smpl_vert_segmentation.json")
    # Context manager so the file handle is closed deterministically.
    with open(seg_path) as seg_file:
        smpl_vert_segmentation = json.load(seg_file)

    # One label per SMPL vertex; vertices absent from the json stay None.
    n = smpl.vertices.shape[0]
    y = np.array([None] * n)
    for key, val in smpl_vert_segmentation.items():
        y[val] = key

    # Honour the caller-supplied k (it used to be ignored in favour of 1).
    classifier = KNeighborsClassifier(n_neighbors=k)
    classifier.fit(smpl.vertices, y)

    y_pred = classifier.predict(recon.vertices)

    recon_labels = {}
    for key in smpl_vert_segmentation.keys():
        recon_labels[key] = list(np.argwhere(y_pred == key).flatten().astype(int))

    return recon_labels
+
+
def extract_cloth(recon, segmentation, K, R, t, smpl=None):
    """
    Extract a portion of a mesh using 2d segmentation coordinates
    Arguments:
        recon: fully clothed mesh
        segmentation: dict with "coord_normalized" (list of 2D polygons)
            and "type_id" (garment category id)
        K: intrinsic matrix of the projection
        R: rotation matrix of the projection
        t: translation vector of the projection
        smpl: body mesh used to label body parts; without it no submesh
            is produced
    Returns:
        Returns a submesh using the segmentation coordinates, or None when
        smpl is not given or no face is selected
    """
    seg_coord = segmentation["coord_normalized"]
    mesh = trimesh.Trimesh(recon.vertices, recon.faces)
    extrinsic = np.zeros((3, 4))
    extrinsic[:3, :3] = R
    extrinsic[:, 3] = t
    P = K[:3, :3] @ extrinsic

    P_inv = np.linalg.pinv(P)

    faces = recon.faces
    num_verts = recon.vertices.shape[0]

    # Each segmentation can contain multiple polygons; a vertex is selected
    # when its (x, y) lies inside any of them.
    seg_mask = np.zeros(num_verts, dtype=bool)
    for polygon in seg_coord:
        n = len(polygon)
        coords_h = np.hstack((polygon, np.ones((n, 1))))
        # Apply the inverse projection on homogeneous 2D coordinates to get
        # the corresponding 3D coordinates
        XYZ = (P_inv @ coords_h[:, :, None])[:, :, 0]
        XYZ = XYZ[:, :3] / XYZ[:, 3, None]

        seg_mask |= Path(XYZ[:, :2]).contains_points(recon.vertices[:, :2])

    if smpl is None:
        return None

    recon_labels = smpl_to_recon_labels(recon, smpl)
    body_parts_to_remove = [
        "rightHand",
        "leftToeBase",
        "leftFoot",
        "rightFoot",
        "head",
        "leftHandIndex1",
        "rightHandIndex1",
        "rightToeBase",
        "leftHand",
    ]
    garment_type = segmentation["type_id"]

    # Remove additional bodyparts that are most likely not part of the
    # segmentation but might intersect (e.g. hand in front of torso)
    # https://github.com/switchablenorms/DeepFashion2
    # Short sleeve clothes
    if garment_type in (1, 3, 10):
        body_parts_to_remove += ["leftForeArm", "rightForeArm"]
    # No sleeves at all or lower body clothes
    elif garment_type in (5, 6, 8, 9, 12, 13):
        body_parts_to_remove += ["leftForeArm", "rightForeArm", "leftArm", "rightArm"]
    # Shorts
    elif garment_type == 7:
        body_parts_to_remove += [
            "leftLeg",
            "rightLeg",
            "leftForeArm",
            "rightForeArm",
            "leftArm",
            "rightArm",
        ]

    label_mask = np.zeros(num_verts, dtype=bool)
    for part in body_parts_to_remove:
        label_mask[np.asarray(recon_labels[part], dtype=int)] = True

    # Keep vertices inside the polygons that do not belong to an excluded
    # body part (element-wise mask arithmetic instead of `and` on lists).
    keep_vertex = seg_mask & ~label_mask

    # A face is kept as soon as one of its three vertices is kept.
    face_mask = keep_vertex[faces].any(axis=1)
    if not face_mask.any():
        return None

    mesh.update_faces(face_mask)
    mesh.remove_unreferenced_vertices()

    return mesh
diff --git a/utils/body_utils/lib/common/config.py b/utils/body_utils/lib/common/config.py
new file mode 100755
index 0000000..0d7b719
--- /dev/null
+++ b/utils/body_utils/lib/common/config.py
@@ -0,0 +1,259 @@
+# -*- coding: utf-8 -*-
+
+# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+# holder of all proprietary rights on this computer program.
+# You can only use this computer program if you have closed
+# a license agreement with MPG or you get the right to use the computer
+# program from someone who is authorized to grant you that right.
+# Any use of the computer program without a valid license is prohibited and
+# liable to prosecution.
+#
+# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+# for Intelligent Systems. All rights reserved.
+#
+# Contact: ps-license@tuebingen.mpg.de
+
+from yacs.config import CfgNode as CN
+import os
+
+# Root config node; every default option of this module hangs off `_C`.
+_C = CN(new_allowed=True)
+
+# needed by trainer
+_C.name = "default"
+_C.gpus = [0]
+_C.test_gpus = [1]
+_C.devices = 1
+_C.root = "./data/"
+_C.ckpt_dir = "./data/ckpt/"
+_C.resume_path = ""
+_C.normal_path = ""
+_C.ifnet_path = ""
+_C.results_path = "./results/"
+_C.projection_mode = "orthogonal"
+_C.num_views = 1
+_C.sdf = False
+_C.sdf_clip = 5.0
+
+_C.lr_netF = 1e-3
+_C.lr_netB = 1e-3
+_C.lr_netD = 1e-3
+_C.lr_G = 1e-3
+_C.weight_decay = 0.0
+_C.momentum = 0.0
+_C.optim = "RMSprop"
+_C.schedule = [5, 10, 15]
+_C.gamma = 0.1
+
+_C.overfit = False
+_C.resume = False
+_C.test_mode = False
+_C.test_uv = False
+_C.draw_geo_thres = 0.60
+_C.num_sanity_val_steps = 2
+_C.fast_dev = 0
+_C.get_fit = False
+_C.agora = False
+_C.optim_cloth = False
+_C.optim_body = False
+_C.mcube_res = 256
+_C.clean_mesh = True
+_C.remesh = False
+_C.body_overlap_thres = 1.0
+_C.cloth_overlap_thres = 1.0
+
+_C.batch_size = 4
+_C.num_threads = 8
+
+_C.num_epoch = 10
+_C.freq_plot = 0.01
+_C.freq_show_train = 0.1
+_C.freq_show_val = 0.2
+_C.freq_eval = 0.5
+_C.accu_grad_batch = 4
+
+_C.vol_res = 128
+
+_C.test_items = ["sv", "mv", "mv-fusion", "hybrid", "dc-pred", "gt"]
+
+# Sub-node grouping the `net.*` options (more `net.*` keys are added after
+# the `bni` group below).
+_C.net = CN()
+_C.net.gtype = "HGPIFuNet"
+_C.net.ctype = "resnet18"
+_C.net.classifierIMF = "MultiSegClassifier"
+_C.net.netIMF = "resnet18"
+_C.net.norm = "group"
+_C.net.norm_mlp = "group"
+_C.net.norm_color = "group"
+_C.net.hg_down = "ave_pool"
+_C.net.num_views = 1
+
+# Sub-node grouping the `bni.*` options.
+_C.bni = CN()
+_C.bni.k = 4
+_C.bni.lambda1 = 1e-4
+_C.bni.boundary_consist = 1e-6
+_C.bni.poisson_depth = 10
+_C.bni.use_poisson = True
+_C.bni.use_smpl = ["face", "hand"]
+_C.bni.use_ifnet = False
+_C.bni.finish = False
+_C.bni.thickness = 0.00
+_C.bni.hand_thres = 4e-2
+_C.bni.face_thres = 6e-2
+_C.bni.hps_type = "pixie"
+_C.bni.texture_src = "image"
+_C.bni.cut_intersection = True
+
+# Element order of the two 4-item conv lists below:
+# kernel_size, stride, dilation, padding
+
+_C.net.conv1 = [7, 2, 1, 3]
+_C.net.conv3x3 = [3, 1, 1, 1]
+
+_C.net.num_stack = 4
+_C.net.num_hourglass = 2
+_C.net.hourglass_dim = 256
+_C.net.voxel_dim = 32
+_C.net.resnet_dim = 120
+_C.net.mlp_dim = [320, 1024, 512, 256, 128, 1]
+_C.net.mlp_dim_knn = [320, 1024, 512, 256, 128, 3]
+_C.net.mlp_dim_color = [513, 1024, 512, 256, 128, 3]
+_C.net.mlp_dim_multiseg = [1088, 2048, 1024, 500]
+_C.net.res_layers = [2, 3, 4]
+_C.net.filter_dim = 256
+_C.net.smpl_dim = 3
+
+_C.net.cly_dim = 3
+_C.net.soft_dim = 64
+_C.net.z_size = 200.0
+_C.net.N_freqs = 10
+_C.net.geo_w = 0.1
+_C.net.norm_w = 0.1
+_C.net.dc_w = 0.1
+_C.net.C_cat_to_G = False
+
+_C.net.skip_hourglass = True
+_C.net.use_tanh = True
+_C.net.soft_onehot = True
+_C.net.no_residual = True
+_C.net.use_attention = False
+
+_C.net.prior_type = "icon"
+_C.net.smpl_feats = ["sdf", "vis"]
+_C.net.use_filter = True
+_C.net.use_cc = False
+_C.net.use_PE = False
+_C.net.use_IGR = False
+_C.net.use_gan = False
+_C.net.in_geo = ()
+_C.net.in_nml = ()
+_C.net.front_losses = ()
+_C.net.back_losses = ()
+
+# Nested sub-node for the `net.gan.*` options.
+_C.net.gan = CN()
+_C.net.gan.dim_detail = 64
+_C.net.gan.lambda_gan = 1
+_C.net.gan.lambda_grad = 10
+_C.net.gan.lambda_recon = 10
+_C.net.gan.d_reg_every = 16
+_C.net.gan.img_res = 512
+
+# Sub-node grouping the `dataset.*` options.
+_C.dataset = CN()
+_C.dataset.root = ""
+_C.dataset.cached = True
+_C.dataset.set_splits = [0.95, 0.04]
+# Only "thuman2" is active; the other dataset names are commented out.
+_C.dataset.types = [
+ # "3dpeople",
+ # "axyz",
+ # "renderpeople",
+ # "renderpeople_p27",
+ # "humanalloy",
+ #"cape"
+ "thuman2"
+]
+# NOTE(review): five scales vs. a single active entry in `types` above —
+# presumably they pair with the dataset names positionally; confirm.
+_C.dataset.scales = [1.0, 100.0, 1.0, 1.0, 100.0 / 39.37]
+_C.dataset.rp_type = "pifu900"
+_C.dataset.th_type = "train"
+_C.dataset.input_size = 512
+_C.dataset.rotation_num = 3
+_C.dataset.num_precomp = 10 # Number of segmentation classifiers
+_C.dataset.num_multiseg = 500 # Number of categories per classifier
+_C.dataset.num_knn = 10 # for loss/error
+_C.dataset.num_knn_dis = 20 # for accuracy
+_C.dataset.num_verts_max = 20000
+_C.dataset.zray_type = False
+_C.dataset.online_smpl = False
+_C.dataset.noise_type = ["z-trans", "pose", "beta"]
+_C.dataset.noise_scale = [0.0, 0.0, 0.0]
+_C.dataset.num_sample_geo = 10000
+_C.dataset.num_sample_color = 0
+_C.dataset.num_sample_seg = 0
+_C.dataset.num_sample_knn = 10000
+
+_C.dataset.sigma_geo = 5.0
+_C.dataset.sigma_color = 0.10
+_C.dataset.sigma_seg = 0.10
+_C.dataset.thickness_threshold = 20.0
+_C.dataset.ray_sample_num = 2
+_C.dataset.semantic_p = False
+_C.dataset.remove_outlier = False
+_C.dataset.laplacian_iters = 0
+_C.dataset.prior_type = "smpl"
+_C.dataset.voxel_res = 128
+
+_C.dataset.train_bsize = 1.0
+_C.dataset.val_bsize = 1.0
+_C.dataset.test_bsize = 1.0
+_C.dataset.single = True
+
+
def get_cfg_defaults():
    """Return an independent copy of the default configuration node."""
    # Hand out a clone so callers can modify their copy without touching
    # the module-level defaults.
    defaults_copy = _C.clone()
    return defaults_copy
+
+
+# Alternatively, provide a way to import the defaults as
+# a global singleton:
+# NOTE: `cfg` is the very same object as `_C` (not a clone), so anything
+# merged into `_C` (see `update_cfg`) is visible through `cfg` as well.
+cfg = _C # users can `from config import cfg`
+
+# cfg = get_cfg_defaults()
+# cfg.merge_from_file('./configs/example.yaml')
+
+# # Now override from a list (opts could come from the command line)
+# opts = ['dataset.root', './data/XXXX', 'learning_rate', '1e-2']
+# cfg.merge_from_list(opts)
+
+
def update_cfg(cfg_file):
    """Merge the options from ``cfg_file`` into the shared defaults.

    The module-level node is modified in place and that same node (not a
    clone) is returned.
    """
    shared_cfg = _C
    shared_cfg.merge_from_file(cfg_file)
    return shared_cfg
+
+
def parse_args(args):
    """Build the configuration selected by ``args.cfg_file``.

    With a config file, the shared defaults are updated from it and
    returned; without one, a fresh clone of the defaults is returned.
    """
    if args.cfg_file is None:
        return get_cfg_defaults()
    return update_cfg(args.cfg_file)
+
+
def parse_args_extend(args):
    """Resolve the configuration for a fresh or a resumed experiment.

    In resume mode (``args.resume`` truthy) the config is loaded from
    ``<args.log_dir>/cfg.yaml`` and, when ``args.misc`` is given, overridden
    from that option list. Otherwise the work is delegated to ``parse_args``.

    Returns:
        The resulting config node (earlier versions returned nothing, so
        existing callers that ignore the result are unaffected).
    Raises:
        ValueError: resume mode was requested but ``args.log_dir`` does not
            exist.
    """
    if not args.resume:
        return parse_args(args)

    if not os.path.exists(args.log_dir):
        raise ValueError("Experiment are set to resume mode, but log directory does not exist.")

    # load log's cfg
    cfg = update_cfg(os.path.join(args.log_dir, "cfg.yaml"))

    if args.misc is not None:
        cfg.merge_from_list(args.misc)

    return cfg
diff --git a/utils/body_utils/lib/common/imutils.py b/utils/body_utils/lib/common/imutils.py
new file mode 100755
index 0000000..64bbf06
--- /dev/null
+++ b/utils/body_utils/lib/common/imutils.py
@@ -0,0 +1,361 @@
+import os
+os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
+import cv2
+import mediapipe as mp
+import torch
+import numpy as np
+import torch.nn.functional as F
+from PIL import Image
+from thirdparties.MODNet.src.models.modnet import MODNet
+
+from torchvision import transforms
+from kornia.geometry.transform import get_affine_matrix2d, warp_affine
+
+
+IMG_NORM_MEAN = [0.485, 0.456, 0.406]
+IMG_NORM_STD = [0.229, 0.224, 0.225]
+
+
def transform_to_tensor(res, mean=None, std=None, is_tensor=False):
    """Compose a torchvision preprocessing pipeline.

    Steps are added in this order, each only when requested: resize to
    ``res`` (skipped if ``res`` is None), conversion to tensor (skipped if
    ``is_tensor``), normalisation (only if both ``mean`` and ``std`` are set).
    """
    pipeline = [transforms.Resize(size=res)] if res is not None else []
    if not is_tensor:
        pipeline.append(transforms.ToTensor())
    if not (mean is None or std is None):
        pipeline.append(transforms.Normalize(mean=mean, std=std))
    return transforms.Compose(pipeline)
+
+
def get_affine_matrix_wh(w1, h1, w2, h2):
    """Affine matrix mapping a ``w1 x h1`` image into a ``w2 x h2`` canvas.

    The same scale (the smaller of the two size ratios) is used on both
    axes, the translation is half the size difference, and no rotation is
    applied.
    """
    uniform_scale = torch.min(torch.tensor([w2 / w1, h2 / h1]))
    scale = uniform_scale.repeat(2).unsqueeze(0)
    center = torch.tensor([w1 / 2.0, h1 / 2.0]).unsqueeze(0)
    transl = torch.tensor([(w2 - w1) / 2.0, (h2 - h1) / 2.0]).unsqueeze(0)

    return get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.]))
+
+
def get_affine_matrix_box(boxes, w2, h2):
    """Per-box affine matrices that centre each box in a ``w2 x h2`` canvas.

    ``boxes`` holds one ``[left, top, right, bottom]`` row per box. Each box
    is scaled uniformly by 0.9 times the smaller of ``w2 / width`` and
    ``h2 / height`` and translated so its centre lands on the canvas centre.
    """
    left, top = boxes[:, 0], boxes[:, 1]
    right, bottom = boxes[:, 2], boxes[:, 3]

    width = right - left    #(N,)
    height = bottom - top    #(N,)
    center = torch.tensor([(left + right) / 2.0, (top + bottom) / 2.0]).T    #(N,2)

    ratios = torch.tensor([w2 / width, h2 / height])
    scale = torch.min(ratios, dim=0)[0].unsqueeze(1).repeat(1, 2) * 0.9    #(N,2)
    transl = torch.cat([w2 / 2.0 - center[:, 0:1], h2 / 2.0 - center[:, 1:2]], dim=1)    #(N,2)
    no_rotation = torch.tensor([0.,] * transl.shape[0])

    return get_affine_matrix2d(transl, center, scale, angle=no_rotation)
+
+
def load_img(img_file):
    """Load an image file as a float RGB tensor.

    Arguments:
        img_file: path of the image; names ending in "exr" are read with
            the any-colour/any-depth flags, everything else unchanged
    Returns:
        (tensor, (height, width)) where the tensor has shape [1, 3, H, W],
        channels in RGB order and values taken from the 8-bit image
    Raises:
        IOError: OpenCV could not read or decode the file
    """
    if img_file.endswith("exr"):
        img = cv2.imread(img_file, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
    else:
        img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)

    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise IOError(f"Could not read image file: {img_file}")

    # considering non 8-bit image
    if img.dtype != np.uint8:
        img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)

    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    # Choose the conversion from the actual channel count rather than the
    # file extension: drop alpha if present, and swap BGR -> RGB.
    if img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    return torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float(), img.shape[:2]
+
+
def get_keypoints(image):
    """Run MediaPipe Holistic on ``image`` and collect its landmarks.

    Returns a dict with the keys "body", "lhand", "rhand" and "face". Each
    value is a tensor with one ``[x, y, z, visibility]`` row per landmark;
    a part that was not detected is replaced by a ``(33, 4)`` zero tensor.
    Only body landmarks carry their own visibility, the other parts use 1.0.
    """
    def collect_xyv(x, body=True):
        rows = [
            torch.Tensor([pt.x, pt.y, pt.z, pt.visibility if body else 1.0])
            for pt in x.landmark
        ]
        return torch.stack(rows).view(-1, 4)

    mp_holistic = mp.solutions.holistic

    with mp_holistic.Holistic(
        static_image_mode=True,
        model_complexity=2,
    ) as holistic:
        results = holistic.process(image)

    fake_kps = torch.zeros(33, 4)

    def part_or_fake(landmarks, body):
        # Missing detections fall back to the all-zero placeholder.
        return collect_xyv(landmarks, body) if landmarks else fake_kps

    return {
        "body": part_or_fake(results.pose_landmarks, True),
        "lhand": part_or_fake(results.left_hand_landmarks, False),
        "rhand": part_or_fake(results.right_hand_landmarks, False),
        "face": part_or_fake(results.face_landmarks, False),
    }
+
+
def remove_floats(mask):
    """Keep only the largest connected region of a binary mask.

    1. find all the contours
    2. fillPoly "True" for the largest one
    3. fillPoly "False" for its children (contours whose parent is the
       largest one)

    Returns a float array of the same shape as ``mask``; it is all zeros
    when ``mask`` contains no contour at all.
    """
    new_mask = np.zeros(mask.shape)

    # [-2:] copes with OpenCV versions that return (image, contours, hierarchy).
    cnts, hier = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
    )[-2:]

    # An empty mask has no contours; there is nothing to keep.
    if len(cnts) == 0:
        return new_mask

    largest = max(range(len(cnts)), key=lambda k: cv2.contourArea(cnts[k]))
    body_cnt = cnts[largest]

    # Last hierarchy column is the parent index of each contour.
    childs_cnt_idx = np.where(np.array(hier)[0, :, -1] == largest)[0]
    childs_cnt = [cnts[idx] for idx in childs_cnt_idx]

    cv2.fillPoly(new_mask, [body_cnt], 1)
    cv2.fillPoly(new_mask, childs_cnt, 0)

    return new_mask
+
+
+def process_image(img_file, hps_type, single, input_res, detector, modnet):
+ """
+ Load an image, detect the person(s) in it and build the per-person crops.
+
+ Arguments:
+ img_file: path handed to load_img
+ hps_type: not read anywhere in this function
+ single: if true keep only the top-scoring label-1 detection, otherwise
+ every label-1 detection with score > 0.9
+ input_res: base resolution; the square canvas side is a multiple of it
+ detector: callable returning a list whose first element has the keys
+ "scores", "labels", "boxes" and "masks"
+ modnet: matting network forwarded to get_seg_mask_MODNet
+ Returns:
+ dict with "img_icon", "img_crop", "img_hps", "img_raw", "img_mask",
+ "uncrop_param", "landmark" and "hands_visibility"
+ """
+
+ img_raw, (in_height, in_width) = load_img(img_file)
+ # Square canvas side: smallest multiple of input_res that is >= 2 * input_res
+ # and not smaller than either image dimension.
+ tgt_res = input_res * 2
+ while tgt_res < in_height or tgt_res < in_width:
+ tgt_res += input_res
+ M_square = get_affine_matrix_wh(in_width, in_height, tgt_res, tgt_res)
+ img_square = warp_affine(
+ img_raw,
+ M_square[:, :2], (tgt_res, ) * 2,
+ mode='bilinear',
+ padding_mode='zeros',
+ align_corners=True
+ )
+
+ # detection for bbox
+ predictions = detector(img_square / 255.)[0]
+
+ if single:
+ # NOTE(review): presumably fails when there is no label-1 detection
+ # (max over an empty selection) — confirm callers guarantee a person.
+ top_score = predictions["scores"][predictions["labels"] == 1].max()
+ human_ids = torch.where(predictions["scores"] == top_score)[0]
+ else:
+ human_ids = torch.logical_and(predictions["labels"] == 1,
+ predictions["scores"] > 0.9).nonzero().squeeze(1)
+
+ boxes = predictions["boxes"][human_ids, :].detach().cpu().numpy()
+ masks = predictions["masks"][human_ids, :, :].permute(0, 2, 3, 1).detach().cpu().numpy()
+
+ # Two crop transforms per box: one at input_res (stored in uncrop_param)
+ # and one at tgt_res (used for the actual high-resolution crop below).
+ M_crop = get_affine_matrix_box(boxes, input_res, input_res)
+ M_crop_tgt_res = get_affine_matrix_box(boxes, tgt_res, tgt_res)
+
+ img_icon_lst = []
+ img_crop_lst = []
+ img_hps_lst = []
+ img_mask_lst = []
+ landmark_lst = []
+ hands_visibility_lst = []
+
+ #print("M_square", M_square, "M_crop": M_crop)
+ # Everything unwrap() needs to map a crop back to the original image.
+ uncrop_param = {
+ "ori_shape": [in_height, in_width],
+ "box_shape": [input_res, input_res],
+ "square_shape": [tgt_res, tgt_res],
+ "M_square": M_square,
+ "M_crop": M_crop
+ }
+
+ for idx in range(len(boxes)):
+
+ # mask out the pixels of others
+ if len(masks) > 1:
+ mask_detection = (masks[np.arange(len(masks)) != idx]).max(axis=0)
+ else:
+ mask_detection = masks[0] * 0.
+
+ # Fourth channel is 255 wherever the other people's mask is below 0.4.
+ img_square_rgba = torch.cat(
+ [img_square.squeeze(0).permute(1, 2, 0),
+ torch.tensor(mask_detection < 0.4) * 255],
+ dim=2
+ )
+ img_crop_tgt_res = warp_affine(
+ img_square_rgba.unsqueeze(0).permute(0, 3, 1, 2),
+ M_crop_tgt_res[idx:idx + 1, :2], (tgt_res, ) * 2,
+ mode='bilinear',
+ padding_mode='zeros',
+ align_corners=True
+ ).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)
+
+ # get accurate person segmentation mask
+ # img_rembg = remove(img_crop, post_process_mask=True, session=new_session("u2net"))
+ # The matte from MODNet is binarised at 0.5.
+ img_mask = (get_seg_mask_MODNet(img_crop_tgt_res, modnet=modnet)[0,0, :, :, None].detach().cpu().numpy() > 0.5).astype(np.float32)
+ img_rembg_highres = np.concatenate([img_crop_tgt_res[:, :, :3], (img_mask*255).astype(np.uint8)], axis=-1)
+
+ mean_icon = std_icon = (0.5, 0.5, 0.5)
+ img_np = (img_rembg_highres[..., :3] * img_mask).astype(np.uint8)
+
+
+ # NOTE(review): 512 is hard-coded here rather than derived from input_res.
+ img_mask_512 = (F.interpolate(torch.tensor(img_mask).permute(2, 0, 1).unsqueeze(0), size=512, mode='bicubic', align_corners=True)[0][0] > 0.5)
+ img_icon = transform_to_tensor(512, mean_icon, std_icon)(
+ Image.fromarray(img_np)
+ ) * img_mask_512.unsqueeze(0)
+ img_hps = transform_to_tensor(224, IMG_NORM_MEAN,
+ IMG_NORM_STD)(Image.fromarray(img_np))
+ landmarks = get_keypoints(img_np)
+
+ # get hands visibility
+ # A hand counts as not visible when the last column of its keypoints
+ # averages to exactly 0 (the all-zero placeholder from get_keypoints).
+ hands_visibility = [True, True]
+ if landmarks['lhand'][:, -1].mean() == 0.:
+ hands_visibility[0] = False
+ if landmarks['rhand'][:, -1].mean() == 0.:
+ hands_visibility[1] = False
+ hands_visibility_lst.append(hands_visibility)
+
+
+ img_crop_lst.append(torch.tensor(img_rembg_highres).permute(2, 0, 1) / 255.0)
+ img_icon_lst.append(img_icon)
+ img_hps_lst.append(img_hps)
+ img_mask_lst.append(torch.tensor(img_mask_512))
+ landmark_lst.append(landmarks['body'])
+
+ # required image tensors / arrays
+
+ # img_icon (tensor): (-1, 1), [3,512,512]
+ # img_hps (tensor): (-2.11, 2.44), [3,224,224]
+
+ # img_np (array): (0, 255), [512,512,3]
+ # img_rembg (array): (0, 255), [512,512,4]
+ # img_mask (array): (0, 1), [512,512,1]
+ # img_crop (array): (0, 255), [512,512,4]
+
+ return_dict = {
+ "img_icon": torch.stack(img_icon_lst).float(), #[N, 3, res, res]
+ "img_crop": torch.stack(img_crop_lst).float(), #[N, 4, res, res]
+ "img_hps": torch.stack(img_hps_lst).float(), #[N, 3, res, res]
+ "img_raw": img_raw, #[1, 3, H, W]
+ "img_mask": torch.stack(img_mask_lst).float(), #[N, res, res]
+ "uncrop_param": uncrop_param,
+ "landmark": torch.stack(landmark_lst), #[N, 33, 4]
+ "hands_visibility": hands_visibility_lst,
+ }
+
+
+ return return_dict
+
+
def blend_rgb_norm(norms, data):
    """Paste predicted normal maps back onto the original image.

    ``norms`` is [N, 3, res, res]. A pixel counts as foreground when its
    channel sum differs from the channel sum of pixel (0, 0) of the first
    map. Each map is resized to the stored box shape, mapped from (-1, 1)
    to (0, 255), un-cropped with ``unwrap`` and alpha-blended over
    ``data["img_raw"]``. The result is returned detached on the CPU.
    """
    uncrop_param = data["uncrop_param"]

    # norms [N, 3, res, res]
    background_sum = norms[0, :, 0, 0].sum()
    masks = (norms.sum(dim=1) != background_sum).float().unsqueeze(1)
    norm_mask = F.interpolate(
        torch.cat([norms, masks], dim=1).detach(),
        size=uncrop_param["box_shape"],
        mode="bilinear",
        align_corners=False
    )
    final = data["img_raw"].type_as(norm_mask)

    for idx in range(len(norms)):
        person = norm_mask[idx:idx + 1]

        norm_pred = (person[:, :3, :, :] + 1.0) * 255.0 / 2.0
        mask_pred = person[:, 3:4, :, :].repeat(1, 3, 1, 1)

        norm_ori = unwrap(norm_pred, uncrop_param, idx)
        mask_ori = unwrap(mask_pred, uncrop_param, idx)

        final = final * (1.0 - mask_ori) + norm_ori * mask_ori

    return final.detach().cpu()
+
+
def unwrap(image, uncrop_param, idx):
    """Map a cropped image back into the original image frame.

    Applies the inverse of the ``idx``-th crop transform (crop -> square
    canvas) followed by the inverse of the square transform (square canvas
    -> original shape), both with bilinear sampling and zero padding.
    """
    device = image.device
    warp_kwargs = dict(mode='bilinear', padding_mode='zeros', align_corners=True)

    crop_to_square = torch.inverse(uncrop_param["M_crop"])[idx:idx + 1, :2].to(device)
    img_square = warp_affine(image, crop_to_square, uncrop_param["square_shape"], **warp_kwargs)

    square_to_ori = torch.inverse(uncrop_param["M_square"])[:, :2].to(device)
    return warp_affine(img_square, square_to_ori, uncrop_param["ori_shape"], **warp_kwargs)
+
def load_MODNet(weight_path):
    """Build a MODNet model, load its weights and switch it to eval mode.

    The model is placed on the GPU when CUDA is available and on the CPU
    otherwise, matching the device choice made in ``get_seg_mask_MODNet``.

    Arguments:
        weight_path: path of the checkpoint (a state dict) to load
    Returns:
        The MODNet module in eval mode
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    modnet = MODNet(backbone_pretrained=False).to(device)

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    weights = torch.load(weight_path, map_location=device)

    # The original code dropped the first 7 characters of every key; that is
    # the length of "module." (presumably a DataParallel prefix — confirm).
    # Strip it only where it is actually present.
    prefix = "module."
    weights = {
        (k[len(prefix):] if k.startswith(prefix) else k): v for k, v in weights.items()
    }
    modnet.load_state_dict(weights)
    modnet.eval()
    return modnet
+
def get_seg_mask_MODNet(im, modnet):
    """Predict a matte for ``im`` with ``modnet`` at the input resolution.

    The image is forced to three channels, normalised with mean/std 0.5,
    resized so both sides are multiples of 32 (rescaled towards 1024 when
    both sides are smaller or both are larger than that), passed through
    the network, and the matte is resized back to the input height/width.
    """
    ref_size = 1024

    # define image to tensor transform
    to_normalized_tensor = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )

    # unify image channels to 3
    im = np.asarray(im)
    if len(im.shape) == 2:
        im = im[:, :, None]
    channels = im.shape[2]
    if channels == 1:
        im = np.repeat(im, 3, axis=2)
    elif channels == 4:
        im = im[:, :, 0:3]

    # convert image to PyTorch tensor and add mini-batch dim
    im = to_normalized_tensor(Image.fromarray(im))[None, :, :, :]

    # resize image for input
    _, _, im_h, im_w = im.shape
    needs_rescale = max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size
    if not needs_rescale:
        im_rh, im_rw = im_h, im_w
    elif im_w >= im_h:
        im_rh = ref_size
        im_rw = int(im_w / im_h * ref_size)
    else:
        im_rw = ref_size
        im_rh = int(im_h / im_w * ref_size)

    # round both sides down to a multiple of 32
    im_rw -= im_rw % 32
    im_rh -= im_rh % 32
    im = F.interpolate(im, size=(im_rh, im_rw), mode='area')

    # inference
    net_input = im.cuda() if torch.cuda.is_available() else im
    _, _, matte = modnet(net_input, True)

    # resize matte back to the input size
    return F.interpolate(matte, size=(im_h, im_w), mode='area')
\ No newline at end of file
diff --git a/utils/body_utils/lib/common/libmesh/inside_mesh.py b/utils/body_utils/lib/common/libmesh/inside_mesh.py
new file mode 100755
index 0000000..eaac43c
--- /dev/null
+++ b/utils/body_utils/lib/common/libmesh/inside_mesh.py
@@ -0,0 +1,154 @@
+import numpy as np
+from .triangle_hash import TriangleHash as _TriangleHash
+
+
def check_mesh_contains(mesh, points, hash_resolution=512):
    """Return ``(contains, hole_points)`` for ``points`` against ``mesh``.

    Thin convenience wrapper: builds a ``MeshIntersector`` with the given
    hash resolution and runs a single query.
    """
    result = MeshIntersector(mesh, hash_resolution).query(points)
    contains, hole_points = result
    return contains, hole_points
+
+
class MeshIntersector:
    """Point-in-mesh test that counts triangle crossings along the z axis.

    Triangles are rescaled into a ``resolution``-sized grid and their xy
    projections are indexed by a ``TriangleIntersector2d``. For a query
    point, the triangles covering it in xy are split into those at or above
    and those below the point; the parities of the two counts decide the
    result.
    """

    def __init__(self, mesh, resolution=512):
        """Index the triangles of ``mesh`` (needs ``vertices`` and ``faces``)."""
        triangles = mesh.vertices[mesh.faces].astype(np.float64)
        n_tri = triangles.shape[0]

        self.resolution = resolution
        self.bbox_min = triangles.reshape(3 * n_tri, 3).min(axis=0)
        self.bbox_max = triangles.reshape(3 * n_tri, 3).max(axis=0)
        # Translate and scale it to [0.5, self.resolution - 0.5]^3
        self.scale = (resolution - 1) / (self.bbox_max - self.bbox_min)
        self.translate = 0.5 - self.scale * self.bbox_min

        self._triangles = triangles = self.rescale(triangles)
        # assert(np.allclose(triangles.reshape(-1, 3).min(0), 0.5))
        # assert(np.allclose(triangles.reshape(-1, 3).max(0), resolution - 0.5))

        triangles2d = triangles[:, :, :2]
        self._tri_intersector2d = TriangleIntersector2d(triangles2d, resolution)

    def query(self, points):
        """Classify ``points`` (N, 3) against the mesh.

        Returns:
            contains: bool array, True where both crossing counts are odd
            hole_points: bool array, True where exactly one count is odd
                (the two directions disagree)
        """
        # Rescale points
        points = self.rescale(points)

        # placeholder result with no hits we'll fill in later
        # (builtin bool: the np.bool alias no longer exists in recent NumPy)
        contains = np.zeros(len(points), dtype=bool)
        hole_points = np.zeros(len(points), dtype=bool)

        # cull points outside of the axis aligned bounding box
        # this avoids running ray tests unless points are close
        inside_aabb = np.all((0 <= points) & (points <= self.resolution), axis=1)
        if not inside_aabb.any():
            return contains, hole_points

        # Only consider points inside bounding box
        mask = inside_aabb
        points = points[mask]

        # Compute intersection depth and check order
        points_indices, tri_indices = self._tri_intersector2d.query(points[:, :2])

        triangles_intersect = self._triangles[tri_indices]
        points_intersect = points[points_indices]

        depth_intersect, abs_n_2 = self.compute_intersection_depth(
            points_intersect, triangles_intersect
        )

        # Count number of intersections in both directions
        # (depths are pre-multiplied by |n_z|, so scale the point's z too;
        # NaN depths from edge-on triangles fail both comparisons)
        smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2
        bigger_depth = depth_intersect < points_intersect[:, 2] * abs_n_2
        points_indices_0 = points_indices[smaller_depth]
        points_indices_1 = points_indices[bigger_depth]

        nintersect0 = np.bincount(points_indices_0, minlength=points.shape[0])
        nintersect1 = np.bincount(points_indices_1, minlength=points.shape[0])

        # Check if point contained in mesh
        contains1 = (np.mod(nintersect0, 2) == 1)
        contains2 = (np.mod(nintersect1, 2) == 1)
        # if (contains1 != contains2).any():
        #     print('Warning: contains1 != contains2 for some points.')
        contains[mask] = (contains1 & contains2)
        hole_points[mask] = np.logical_xor(contains1, contains2)
        return contains, hole_points

    def compute_intersection_depth(self, points, triangles):
        """Depth at which the vertical line through each point meets its triangle.

        ``points`` (M, 3) and ``triangles`` (M, 3, 3) are paired row by row.

        Returns:
            depth_intersect: intersection z multiplied by ``|n_z|``; NaN
                where the triangle normal has no z component
            abs_n_2: ``|n_z|`` of each triangle normal
        """
        t1 = triangles[:, 0, :]
        t2 = triangles[:, 1, :]
        t3 = triangles[:, 2, :]

        v1 = t3 - t1
        v2 = t2 - t1
        # v1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
        # v2 = v2 / np.linalg.norm(v2, axis=-1, keepdims=True)

        normals = np.cross(v1, v2)
        alpha = np.sum(normals[:, :2] * (t1[:, :2] - points[:, :2]), axis=1)

        n_2 = normals[:, 2]
        t1_2 = t1[:, 2]
        s_n_2 = np.sign(n_2)
        abs_n_2 = np.abs(n_2)

        # Triangles seen edge-on (n_z == 0) have no single intersection depth.
        mask = (abs_n_2 != 0)

        depth_intersect = np.full(points.shape[0], np.nan)
        depth_intersect[mask] = \
            t1_2[mask] * abs_n_2[mask] + alpha[mask] * s_n_2[mask]

        # Test the depth:
        # TODO: remove and put into tests
        # points_new = np.concatenate([points[:, :2], depth_intersect[:, None]], axis=1)
        # alpha = (normals * t1).sum(-1)
        # mask = (depth_intersect == depth_intersect)
        # assert(np.allclose((points_new[mask] * normals[mask]).sum(-1),
        #                    alpha[mask]))
        return depth_intersect, abs_n_2

    def rescale(self, array):
        """Apply the mesh-to-grid scale and translation to ``array``."""
        array = self.scale * array + self.translate
        return array
+
+
class TriangleIntersector2d:
    """Finds which 2D triangles contain which 2D points.

    Candidate (point, triangle) pairs come from a ``_TriangleHash``; each
    candidate is then verified with an exact point-in-triangle test.
    """

    def __init__(self, triangles, resolution=128):
        """Store ``triangles`` (T, 3, 2) and build the hash over them."""
        self.triangles = triangles
        self.tri_hash = _TriangleHash(triangles, resolution)

    def query(self, points):
        """Return index arrays ``(point_indices, tri_indices)`` of confirmed hits."""
        point_indices, tri_indices = self.tri_hash.query(points)
        point_indices = np.array(point_indices, dtype=np.int64)
        tri_indices = np.array(tri_indices, dtype=np.int64)
        points = points[point_indices]
        triangles = self.triangles[tri_indices]
        # Drop the candidates that fail the exact test.
        mask = self.check_triangles(points, triangles)
        point_indices = point_indices[mask]
        tri_indices = tri_indices[mask]
        return point_indices, tri_indices

    def check_triangles(self, points, triangles):
        """Row-wise strict point-in-triangle test.

        ``points`` (M, 2) and ``triangles`` (M, 3, 2) are paired row by row.
        Points on an edge and triangles with zero area yield False.
        """
        # builtin bool: the np.bool alias no longer exists in recent NumPy
        contains = np.zeros(points.shape[0], dtype=bool)
        A = triangles[:, :2] - triangles[:, 2:]
        A = A.transpose([0, 2, 1])
        y = points - triangles[:, 2]

        detA = A[:, 0, 0] * A[:, 1, 1] - A[:, 0, 1] * A[:, 1, 0]

        # Skip degenerate triangles (zero determinant).
        mask = (np.abs(detA) != 0.)
        A = A[mask]
        y = y[mask]
        detA = detA[mask]

        s_detA = np.sign(detA)
        abs_detA = np.abs(detA)

        # u, v are the barycentric coordinates scaled by |detA|, which avoids
        # a division.
        u = (A[:, 1, 1] * y[:, 0] - A[:, 0, 1] * y[:, 1]) * s_detA
        v = (-A[:, 1, 0] * y[:, 0] + A[:, 0, 0] * y[:, 1]) * s_detA

        sum_uv = u + v
        contains[mask] = (
            (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA) & (0 < sum_uv) &
            (sum_uv < abs_detA)
        )
        return contains
diff --git a/utils/body_utils/lib/common/libmesh/setup.py b/utils/body_utils/lib/common/libmesh/setup.py
new file mode 100755
index 0000000..38ac162
--- /dev/null
+++ b/utils/body_utils/lib/common/libmesh/setup.py
@@ -0,0 +1,5 @@
+from setuptools import setup
+from Cython.Build import cythonize
+import numpy
+
# Build every Cython source (*.pyx) in this directory as part of the
# `libmesh` package; the NumPy headers are added to the include path.
cython_extensions = cythonize("*.pyx")
setup(
    name='libmesh',
    ext_modules=cython_extensions,
    include_dirs=[numpy.get_include()],
)
diff --git a/utils/body_utils/lib/common/libmesh/triangle_hash.cpp b/utils/body_utils/lib/common/libmesh/triangle_hash.cpp
new file mode 100755
index 0000000..a4b0f55
--- /dev/null
+++ b/utils/body_utils/lib/common/libmesh/triangle_hash.cpp
@@ -0,0 +1,24295 @@
+/* Generated by Cython 0.29.33 */
+
+/* BEGIN: Cython Metadata
+{
+ "distutils": {
+ "depends": [],
+ "language": "c++",
+ "name": "triangle_hash",
+ "sources": [
+ "triangle_hash.pyx"
+ ]
+ },
+ "module_name": "triangle_hash"
+}
+END: Cython Metadata */
+
+#ifndef PY_SSIZE_T_CLEAN
+#define PY_SSIZE_T_CLEAN
+#endif /* PY_SSIZE_T_CLEAN */
+#include "Python.h"
+#ifndef Py_PYTHON_H
+ #error Python headers needed to compile C extensions, please install development version of Python.
+#elif PY_VERSION_HEX < 0x02060000 || (0x03000000 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x03030000)
+ #error Cython requires Python 2.6+ or Python 3.3+.
+#else
+#define CYTHON_ABI "0_29_33"
+#define CYTHON_HEX_VERSION 0x001D21F0
+#define CYTHON_FUTURE_DIVISION 0
+/* offsetof() comes from <stddef.h>; the macro below is only a fallback for
+   toolchains whose header does not provide it. */
+#include <stddef.h>
+#ifndef offsetof
+  #define offsetof(type, member) ( (size_t) & ((type*)0) -> member )
+#endif
+#if !defined(WIN32) && !defined(MS_WINDOWS)
+ #ifndef __stdcall
+ #define __stdcall
+ #endif
+ #ifndef __cdecl
+ #define __cdecl
+ #endif
+ #ifndef __fastcall
+ #define __fastcall
+ #endif
+#endif
+#ifndef DL_IMPORT
+ #define DL_IMPORT(t) t
+#endif
+#ifndef DL_EXPORT
+ #define DL_EXPORT(t) t
+#endif
+#define __PYX_COMMA ,
+#ifndef HAVE_LONG_LONG
+ #if PY_VERSION_HEX >= 0x02070000
+ #define HAVE_LONG_LONG
+ #endif
+#endif
+#ifndef PY_LONG_LONG
+ #define PY_LONG_LONG LONG_LONG
+#endif
+#ifndef Py_HUGE_VAL
+ #define Py_HUGE_VAL HUGE_VAL
+#endif
+#ifdef PYPY_VERSION
+ #define CYTHON_COMPILING_IN_PYPY 1
+ #define CYTHON_COMPILING_IN_PYSTON 0
+ #define CYTHON_COMPILING_IN_CPYTHON 0
+ #define CYTHON_COMPILING_IN_NOGIL 0
+ #undef CYTHON_USE_TYPE_SLOTS
+ #define CYTHON_USE_TYPE_SLOTS 0
+ #undef CYTHON_USE_PYTYPE_LOOKUP
+ #define CYTHON_USE_PYTYPE_LOOKUP 0
+ #if PY_VERSION_HEX < 0x03050000
+ #undef CYTHON_USE_ASYNC_SLOTS
+ #define CYTHON_USE_ASYNC_SLOTS 0
+ #elif !defined(CYTHON_USE_ASYNC_SLOTS)
+ #define CYTHON_USE_ASYNC_SLOTS 1
+ #endif
+ #undef CYTHON_USE_PYLIST_INTERNALS
+ #define CYTHON_USE_PYLIST_INTERNALS 0
+ #undef CYTHON_USE_UNICODE_INTERNALS
+ #define CYTHON_USE_UNICODE_INTERNALS 0
+ #undef CYTHON_USE_UNICODE_WRITER
+ #define CYTHON_USE_UNICODE_WRITER 0
+ #undef CYTHON_USE_PYLONG_INTERNALS
+ #define CYTHON_USE_PYLONG_INTERNALS 0
+ #undef CYTHON_AVOID_BORROWED_REFS
+ #define CYTHON_AVOID_BORROWED_REFS 1
+ #undef CYTHON_ASSUME_SAFE_MACROS
+ #define CYTHON_ASSUME_SAFE_MACROS 0
+ #undef CYTHON_UNPACK_METHODS
+ #define CYTHON_UNPACK_METHODS 0
+ #undef CYTHON_FAST_THREAD_STATE
+ #define CYTHON_FAST_THREAD_STATE 0
+ #undef CYTHON_FAST_PYCALL
+ #define CYTHON_FAST_PYCALL 0
+ #undef CYTHON_PEP489_MULTI_PHASE_INIT
+ #define CYTHON_PEP489_MULTI_PHASE_INIT 0
+ #undef CYTHON_USE_TP_FINALIZE
+ #define CYTHON_USE_TP_FINALIZE 0
+ #undef CYTHON_USE_DICT_VERSIONS
+ #define CYTHON_USE_DICT_VERSIONS 0
+ #undef CYTHON_USE_EXC_INFO_STACK
+ #define CYTHON_USE_EXC_INFO_STACK 0
+ #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC
+ #define CYTHON_UPDATE_DESCRIPTOR_DOC 0
+ #endif
+#elif defined(PYSTON_VERSION)
+ #define CYTHON_COMPILING_IN_PYPY 0
+ #define CYTHON_COMPILING_IN_PYSTON 1
+ #define CYTHON_COMPILING_IN_CPYTHON 0
+ #define CYTHON_COMPILING_IN_NOGIL 0
+ #ifndef CYTHON_USE_TYPE_SLOTS
+ #define CYTHON_USE_TYPE_SLOTS 1
+ #endif
+ #undef CYTHON_USE_PYTYPE_LOOKUP
+ #define CYTHON_USE_PYTYPE_LOOKUP 0
+ #undef CYTHON_USE_ASYNC_SLOTS
+ #define CYTHON_USE_ASYNC_SLOTS 0
+ #undef CYTHON_USE_PYLIST_INTERNALS
+ #define CYTHON_USE_PYLIST_INTERNALS 0
+ #ifndef CYTHON_USE_UNICODE_INTERNALS
+ #define CYTHON_USE_UNICODE_INTERNALS 1
+ #endif
+ #undef CYTHON_USE_UNICODE_WRITER
+ #define CYTHON_USE_UNICODE_WRITER 0
+ #undef CYTHON_USE_PYLONG_INTERNALS
+ #define CYTHON_USE_PYLONG_INTERNALS 0
+ #ifndef CYTHON_AVOID_BORROWED_REFS
+ #define CYTHON_AVOID_BORROWED_REFS 0
+ #endif
+ #ifndef CYTHON_ASSUME_SAFE_MACROS
+ #define CYTHON_ASSUME_SAFE_MACROS 1
+ #endif
+ #ifndef CYTHON_UNPACK_METHODS
+ #define CYTHON_UNPACK_METHODS 1
+ #endif
+ #undef CYTHON_FAST_THREAD_STATE
+ #define CYTHON_FAST_THREAD_STATE 0
+ #undef CYTHON_FAST_PYCALL
+ #define CYTHON_FAST_PYCALL 0
+ #undef CYTHON_PEP489_MULTI_PHASE_INIT
+ #define CYTHON_PEP489_MULTI_PHASE_INIT 0
+ #undef CYTHON_USE_TP_FINALIZE
+ #define CYTHON_USE_TP_FINALIZE 0
+ #undef CYTHON_USE_DICT_VERSIONS
+ #define CYTHON_USE_DICT_VERSIONS 0
+ #undef CYTHON_USE_EXC_INFO_STACK
+ #define CYTHON_USE_EXC_INFO_STACK 0
+ #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC
+ #define CYTHON_UPDATE_DESCRIPTOR_DOC 0
+ #endif
+#elif defined(PY_NOGIL)
+ #define CYTHON_COMPILING_IN_PYPY 0
+ #define CYTHON_COMPILING_IN_PYSTON 0
+ #define CYTHON_COMPILING_IN_CPYTHON 0
+ #define CYTHON_COMPILING_IN_NOGIL 1
+ #ifndef CYTHON_USE_TYPE_SLOTS
+ #define CYTHON_USE_TYPE_SLOTS 1
+ #endif
+ #undef CYTHON_USE_PYTYPE_LOOKUP
+ #define CYTHON_USE_PYTYPE_LOOKUP 0
+ #ifndef CYTHON_USE_ASYNC_SLOTS
+ #define CYTHON_USE_ASYNC_SLOTS 1
+ #endif
+ #undef CYTHON_USE_PYLIST_INTERNALS
+ #define CYTHON_USE_PYLIST_INTERNALS 0
+ #ifndef CYTHON_USE_UNICODE_INTERNALS
+ #define CYTHON_USE_UNICODE_INTERNALS 1
+ #endif
+ #undef CYTHON_USE_UNICODE_WRITER
+ #define CYTHON_USE_UNICODE_WRITER 0
+ #undef CYTHON_USE_PYLONG_INTERNALS
+ #define CYTHON_USE_PYLONG_INTERNALS 0
+ #ifndef CYTHON_AVOID_BORROWED_REFS
+ #define CYTHON_AVOID_BORROWED_REFS 0
+ #endif
+ #ifndef CYTHON_ASSUME_SAFE_MACROS
+ #define CYTHON_ASSUME_SAFE_MACROS 1
+ #endif
+ #ifndef CYTHON_UNPACK_METHODS
+ #define CYTHON_UNPACK_METHODS 1
+ #endif
+ #undef CYTHON_FAST_THREAD_STATE
+ #define CYTHON_FAST_THREAD_STATE 0
+ #undef CYTHON_FAST_PYCALL
+ #define CYTHON_FAST_PYCALL 0
+ #ifndef CYTHON_PEP489_MULTI_PHASE_INIT
+ #define CYTHON_PEP489_MULTI_PHASE_INIT 1
+ #endif
+ #ifndef CYTHON_USE_TP_FINALIZE
+ #define CYTHON_USE_TP_FINALIZE 1
+ #endif
+ #undef CYTHON_USE_DICT_VERSIONS
+ #define CYTHON_USE_DICT_VERSIONS 0
+ #undef CYTHON_USE_EXC_INFO_STACK
+ #define CYTHON_USE_EXC_INFO_STACK 0
+#else
+ #define CYTHON_COMPILING_IN_PYPY 0
+ #define CYTHON_COMPILING_IN_PYSTON 0
+ #define CYTHON_COMPILING_IN_CPYTHON 1
+ #define CYTHON_COMPILING_IN_NOGIL 0
+ #ifndef CYTHON_USE_TYPE_SLOTS
+ #define CYTHON_USE_TYPE_SLOTS 1
+ #endif
+ #if PY_VERSION_HEX < 0x02070000
+ #undef CYTHON_USE_PYTYPE_LOOKUP
+ #define CYTHON_USE_PYTYPE_LOOKUP 0
+ #elif !defined(CYTHON_USE_PYTYPE_LOOKUP)
+ #define CYTHON_USE_PYTYPE_LOOKUP 1
+ #endif
+ #if PY_MAJOR_VERSION < 3
+ #undef CYTHON_USE_ASYNC_SLOTS
+ #define CYTHON_USE_ASYNC_SLOTS 0
+ #elif !defined(CYTHON_USE_ASYNC_SLOTS)
+ #define CYTHON_USE_ASYNC_SLOTS 1
+ #endif
+ #if PY_VERSION_HEX < 0x02070000
+ #undef CYTHON_USE_PYLONG_INTERNALS
+ #define CYTHON_USE_PYLONG_INTERNALS 0
+ #elif !defined(CYTHON_USE_PYLONG_INTERNALS)
+ #define CYTHON_USE_PYLONG_INTERNALS 1
+ #endif
+ #ifndef CYTHON_USE_PYLIST_INTERNALS
+ #define CYTHON_USE_PYLIST_INTERNALS 1
+ #endif
+ #ifndef CYTHON_USE_UNICODE_INTERNALS
+ #define CYTHON_USE_UNICODE_INTERNALS 1
+ #endif
+ #if PY_VERSION_HEX < 0x030300F0 || PY_VERSION_HEX >= 0x030B00A2
+ #undef CYTHON_USE_UNICODE_WRITER
+ #define CYTHON_USE_UNICODE_WRITER 0
+ #elif !defined(CYTHON_USE_UNICODE_WRITER)
+ #define CYTHON_USE_UNICODE_WRITER 1
+ #endif
+ #ifndef CYTHON_AVOID_BORROWED_REFS
+ #define CYTHON_AVOID_BORROWED_REFS 0
+ #endif
+ #ifndef CYTHON_ASSUME_SAFE_MACROS
+ #define CYTHON_ASSUME_SAFE_MACROS 1
+ #endif
+ #ifndef CYTHON_UNPACK_METHODS
+ #define CYTHON_UNPACK_METHODS 1
+ #endif
+ #if PY_VERSION_HEX >= 0x030B00A4
+ #undef CYTHON_FAST_THREAD_STATE
+ #define CYTHON_FAST_THREAD_STATE 0
+ #elif !defined(CYTHON_FAST_THREAD_STATE)
+ #define CYTHON_FAST_THREAD_STATE 1
+ #endif
+ #ifndef CYTHON_FAST_PYCALL
+ #define CYTHON_FAST_PYCALL (PY_VERSION_HEX < 0x030A0000)
+ #endif
+ #ifndef CYTHON_PEP489_MULTI_PHASE_INIT
+ #define CYTHON_PEP489_MULTI_PHASE_INIT (PY_VERSION_HEX >= 0x03050000)
+ #endif
+ #ifndef CYTHON_USE_TP_FINALIZE
+ #define CYTHON_USE_TP_FINALIZE (PY_VERSION_HEX >= 0x030400a1)
+ #endif
+ #ifndef CYTHON_USE_DICT_VERSIONS
+ #define CYTHON_USE_DICT_VERSIONS (PY_VERSION_HEX >= 0x030600B1)
+ #endif
+ #if PY_VERSION_HEX >= 0x030B00A4
+ #undef CYTHON_USE_EXC_INFO_STACK
+ #define CYTHON_USE_EXC_INFO_STACK 0
+ #elif !defined(CYTHON_USE_EXC_INFO_STACK)
+ #define CYTHON_USE_EXC_INFO_STACK (PY_VERSION_HEX >= 0x030700A3)
+ #endif
+ #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC
+ #define CYTHON_UPDATE_DESCRIPTOR_DOC 1
+ #endif
+#endif
+#if !defined(CYTHON_FAST_PYCCALL)
+#define CYTHON_FAST_PYCCALL (CYTHON_FAST_PYCALL && PY_VERSION_HEX >= 0x030600B1)
+#endif
+#if CYTHON_USE_PYLONG_INTERNALS
+ #if PY_MAJOR_VERSION < 3
+ #include "longintrepr.h"
+ #endif
+ #undef SHIFT
+ #undef BASE
+ #undef MASK
+ #ifdef SIZEOF_VOID_P
+ enum { __pyx_check_sizeof_voidp = 1 / (int)(SIZEOF_VOID_P == sizeof(void*)) };
+ #endif
+#endif
+#ifndef __has_attribute
+ #define __has_attribute(x) 0
+#endif
+#ifndef __has_cpp_attribute
+ #define __has_cpp_attribute(x) 0
+#endif
+#ifndef CYTHON_RESTRICT
+ #if defined(__GNUC__)
+ #define CYTHON_RESTRICT __restrict__
+ #elif defined(_MSC_VER) && _MSC_VER >= 1400
+ #define CYTHON_RESTRICT __restrict
+ #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L
+ #define CYTHON_RESTRICT restrict
+ #else
+ #define CYTHON_RESTRICT
+ #endif
+#endif
+#ifndef CYTHON_UNUSED
+# if defined(__GNUC__)
+# if !(defined(__cplusplus)) || (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 4))
+# define CYTHON_UNUSED __attribute__ ((__unused__))
+# else
+# define CYTHON_UNUSED
+# endif
+# elif defined(__ICC) || (defined(__INTEL_COMPILER) && !defined(_MSC_VER))
+# define CYTHON_UNUSED __attribute__ ((__unused__))
+# else
+# define CYTHON_UNUSED
+# endif
+#endif
+/* CYTHON_MAYBE_UNUSED_VAR(x): marks x as used without doing anything with it.
+   C++ gets an empty function template (the template parameter list had been
+   lost), C gets a cast to void. */
+#ifndef CYTHON_MAYBE_UNUSED_VAR
+#  if defined(__cplusplus)
+     template<class T> void CYTHON_MAYBE_UNUSED_VAR( const T& ) { }
+#  else
+#    define CYTHON_MAYBE_UNUSED_VAR(x) (void)(x)
+#  endif
+#endif
+#ifndef CYTHON_NCP_UNUSED
+# if CYTHON_COMPILING_IN_CPYTHON
+# define CYTHON_NCP_UNUSED
+# else
+# define CYTHON_NCP_UNUSED CYTHON_UNUSED
+# endif
+#endif
+#define __Pyx_void_to_None(void_result) ((void)(void_result), Py_INCREF(Py_None), Py_None)
+/* Provide uint8_t / uint32_t: hand-written typedefs for MSVC builds that have
+   no stdint header, <stdint.h> for every other compiler. */
+#ifdef _MSC_VER
+    #ifndef _MSC_STDINT_H_
+        #if _MSC_VER < 1300
+            typedef unsigned char     uint8_t;
+            typedef unsigned int      uint32_t;
+        #else
+            typedef unsigned __int8   uint8_t;
+            typedef unsigned __int32  uint32_t;
+        #endif
+    #endif
+#else
+    #include <stdint.h>
+#endif
+#ifndef CYTHON_FALLTHROUGH
+ #if defined(__cplusplus) && __cplusplus >= 201103L
+ #if __has_cpp_attribute(fallthrough)
+ #define CYTHON_FALLTHROUGH [[fallthrough]]
+ #elif __has_cpp_attribute(clang::fallthrough)
+ #define CYTHON_FALLTHROUGH [[clang::fallthrough]]
+ #elif __has_cpp_attribute(gnu::fallthrough)
+ #define CYTHON_FALLTHROUGH [[gnu::fallthrough]]
+ #endif
+ #endif
+ #ifndef CYTHON_FALLTHROUGH
+ #if __has_attribute(fallthrough)
+ #define CYTHON_FALLTHROUGH __attribute__((fallthrough))
+ #else
+ #define CYTHON_FALLTHROUGH
+ #endif
+ #endif
+ #if defined(__clang__ ) && defined(__apple_build_version__)
+ #if __apple_build_version__ < 7000000
+ #undef CYTHON_FALLTHROUGH
+ #define CYTHON_FALLTHROUGH
+ #endif
+ #endif
+#endif
+
+#ifndef __cplusplus
+ #error "Cython files generated with the C++ option must be compiled with a C++ compiler."
+#endif
+#ifndef CYTHON_INLINE
+ #if defined(__clang__)
+ #define CYTHON_INLINE __inline__ __attribute__ ((__unused__))
+ #else
+ #define CYTHON_INLINE inline
+ #endif
+#endif
+/* Explicitly run the destructor of x without releasing its storage. */
+template<typename T>
+void __Pyx_call_destructor(T& x) {
+    x.~T();
+}
+/* Pointer-backed stand-in for a T& that, unlike a real reference, can be
+   default-constructed (holding NULL) and re-assigned. */
+template<typename T>
+class __Pyx_FakeReference {
+  public:
+    __Pyx_FakeReference() : ptr(NULL) { }
+    /* Stores the address of ref; constness is cast away. */
+    __Pyx_FakeReference(const T& ref) : ptr(const_cast<T*>(&ref)) { }
+    T *operator->() { return ptr; }
+    /* Address-of yields the address of the referenced object, not of the wrapper. */
+    T *operator&() { return ptr; }
+    operator T&() { return *ptr; }
+    /* Comparisons are forwarded to the referenced object. */
+    template<typename U> bool operator ==(U other) { return *ptr == other; }
+    template<typename U> bool operator !=(U other) { return *ptr != other; }
+  private:
+    T *ptr;
+};
+
+#if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX < 0x02070600 && !defined(Py_OptimizeFlag)
+ #define Py_OptimizeFlag 0
+#endif
+#define __PYX_BUILD_PY_SSIZE_T "n"
+#define CYTHON_FORMAT_SSIZE_T "z"
+#if PY_MAJOR_VERSION < 3
+ #define __Pyx_BUILTIN_MODULE_NAME "__builtin__"
+ #define __Pyx_PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\
+ PyCode_New(a+k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)
+ #define __Pyx_DefaultClassType PyClass_Type
+#else
+ #define __Pyx_BUILTIN_MODULE_NAME "builtins"
+ #define __Pyx_DefaultClassType PyType_Type
+#if PY_VERSION_HEX >= 0x030B00A1
+ /*
+  * Build a code object by creating an empty one with PyCode_NewEmpty() and
+  * calling its replace() method with the requested co_* fields.
+  *
+  * a, k, l, s, f fill co_argcount, co_kwonlyargcount, co_nlocals, co_stacksize
+  * and co_flags; co_posonlyargcount is always 0. code, c, n, v, fv, cell and
+  * lnos are stored as co_code, co_consts, co_names, co_varnames, co_freevars,
+  * co_cellvars and co_linetable. fn and name are converted to UTF-8 and passed
+  * to PyCode_NewEmpty() together with fline.
+  *
+  * Returns a new code object, or NULL on failure. An exception that was
+  * pending on entry is fetched first and restored before returning.
+  */
+ static CYTHON_INLINE PyCodeObject* __Pyx_PyCode_New(int a, int k, int l, int s, int f,
+                                                     PyObject *code, PyObject *c, PyObject* n, PyObject *v,
+                                                     PyObject *fv, PyObject *cell, PyObject* fn,
+                                                     PyObject *name, int fline, PyObject *lnos) {
+     PyObject *kwds=NULL, *argcount=NULL, *posonlyargcount=NULL, *kwonlyargcount=NULL;
+     PyObject *nlocals=NULL, *stacksize=NULL, *flags=NULL, *replace=NULL, *call_result=NULL, *empty=NULL;
+     const char *fn_cstr=NULL;
+     const char *name_cstr=NULL;
+     PyCodeObject* co=NULL;
+     PyObject *type, *value, *traceback;
+     PyErr_Fetch(&type, &value, &traceback);
+     if (!(kwds=PyDict_New())) goto end;
+     if (!(argcount=PyLong_FromLong(a))) goto end;
+     if (PyDict_SetItemString(kwds, "co_argcount", argcount) != 0) goto end;
+     if (!(posonlyargcount=PyLong_FromLong(0))) goto end;
+     if (PyDict_SetItemString(kwds, "co_posonlyargcount", posonlyargcount) != 0) goto end;
+     if (!(kwonlyargcount=PyLong_FromLong(k))) goto end;
+     if (PyDict_SetItemString(kwds, "co_kwonlyargcount", kwonlyargcount) != 0) goto end;
+     if (!(nlocals=PyLong_FromLong(l))) goto end;
+     if (PyDict_SetItemString(kwds, "co_nlocals", nlocals) != 0) goto end;
+     if (!(stacksize=PyLong_FromLong(s))) goto end;
+     if (PyDict_SetItemString(kwds, "co_stacksize", stacksize) != 0) goto end;
+     if (!(flags=PyLong_FromLong(f))) goto end;
+     if (PyDict_SetItemString(kwds, "co_flags", flags) != 0) goto end;
+     if (PyDict_SetItemString(kwds, "co_code", code) != 0) goto end;
+     if (PyDict_SetItemString(kwds, "co_consts", c) != 0) goto end;
+     if (PyDict_SetItemString(kwds, "co_names", n) != 0) goto end;
+     if (PyDict_SetItemString(kwds, "co_varnames", v) != 0) goto end;
+     if (PyDict_SetItemString(kwds, "co_freevars", fv) != 0) goto end;
+     if (PyDict_SetItemString(kwds, "co_cellvars", cell) != 0) goto end;
+     if (PyDict_SetItemString(kwds, "co_linetable", lnos) != 0) goto end;
+     if (!(fn_cstr=PyUnicode_AsUTF8AndSize(fn, NULL))) goto end;
+     if (!(name_cstr=PyUnicode_AsUTF8AndSize(name, NULL))) goto end;
+     if (!(co = PyCode_NewEmpty(fn_cstr, name_cstr, fline))) goto end;
+     if (!(replace = PyObject_GetAttrString((PyObject*)co, "replace"))) goto cleanup_code_too;
+     if (!(empty = PyTuple_New(0))) goto cleanup_code_too; // unfortunately __pyx_empty_tuple isn't available here
+     if (!(call_result = PyObject_Call(replace, empty, kwds))) goto cleanup_code_too;
+     /* Success: the replaced code object supersedes the empty template. */
+     Py_XDECREF((PyObject*)co);
+     co = (PyCodeObject*)call_result;
+     call_result = NULL;
+     if (0) {
+         /* Reached only by goto: drop the template and report failure. */
+         cleanup_code_too:
+         Py_XDECREF((PyObject*)co);
+         co = NULL;
+     }
+     end:
+     Py_XDECREF(kwds);
+     Py_XDECREF(argcount);
+     Py_XDECREF(posonlyargcount);
+     Py_XDECREF(kwonlyargcount);
+     Py_XDECREF(nlocals);
+     Py_XDECREF(stacksize);
+     /* flags is a new reference from PyLong_FromLong() and must be released
+        like the other temporaries (it was previously leaked). */
+     Py_XDECREF(flags);
+     Py_XDECREF(replace);
+     Py_XDECREF(call_result);
+     Py_XDECREF(empty);
+     if (type) {
+         PyErr_Restore(type, value, traceback);
+     }
+     return co;
+ }
+#else
+ #define __Pyx_PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\
+ PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)
+#endif
+ #define __Pyx_DefaultClassType PyType_Type
+#endif
+#ifndef Py_TPFLAGS_CHECKTYPES
+ #define Py_TPFLAGS_CHECKTYPES 0
+#endif
+#ifndef Py_TPFLAGS_HAVE_INDEX
+ #define Py_TPFLAGS_HAVE_INDEX 0
+#endif
+#ifndef Py_TPFLAGS_HAVE_NEWBUFFER
+ #define Py_TPFLAGS_HAVE_NEWBUFFER 0
+#endif
+#ifndef Py_TPFLAGS_HAVE_FINALIZE
+ #define Py_TPFLAGS_HAVE_FINALIZE 0
+#endif
+#ifndef METH_STACKLESS
+ #define METH_STACKLESS 0
+#endif
+#if PY_VERSION_HEX <= 0x030700A3 || !defined(METH_FASTCALL)
+ #ifndef METH_FASTCALL
+ #define METH_FASTCALL 0x80
+ #endif
+ typedef PyObject *(*__Pyx_PyCFunctionFast) (PyObject *self, PyObject *const *args, Py_ssize_t nargs);
+ typedef PyObject *(*__Pyx_PyCFunctionFastWithKeywords) (PyObject *self, PyObject *const *args,
+ Py_ssize_t nargs, PyObject *kwnames);
+#else
+ #define __Pyx_PyCFunctionFast _PyCFunctionFast
+ #define __Pyx_PyCFunctionFastWithKeywords _PyCFunctionFastWithKeywords
+#endif
+#if CYTHON_FAST_PYCCALL
+#define __Pyx_PyFastCFunction_Check(func)\
+ ((PyCFunction_Check(func) && (METH_FASTCALL == (PyCFunction_GET_FLAGS(func) & ~(METH_CLASS | METH_STATIC | METH_COEXIST | METH_KEYWORDS | METH_STACKLESS)))))
+#else
+#define __Pyx_PyFastCFunction_Check(func) 0
+#endif
+/* On PyPy, when PyObject_Malloc is not available as a macro, route the object
+   allocator onto the PyMem_* functions. Realloc takes (pointer, new size), as
+   both PyObject_Realloc and PyMem_Realloc do. */
+#if CYTHON_COMPILING_IN_PYPY && !defined(PyObject_Malloc)
+  #define PyObject_Malloc(s)      PyMem_Malloc(s)
+  #define PyObject_Free(p)        PyMem_Free(p)
+  #define PyObject_Realloc(p, n)  PyMem_Realloc(p, n)
+#endif
+#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030400A1
+ #define PyMem_RawMalloc(n) PyMem_Malloc(n)
+ #define PyMem_RawRealloc(p, n) PyMem_Realloc(p, n)
+ #define PyMem_RawFree(p) PyMem_Free(p)
+#endif
+#if CYTHON_COMPILING_IN_PYSTON
+ #define __Pyx_PyCode_HasFreeVars(co) PyCode_HasFreeVars(co)
+ #define __Pyx_PyFrame_SetLineNumber(frame, lineno) PyFrame_SetLineNumber(frame, lineno)
+#else
+ #define __Pyx_PyCode_HasFreeVars(co) (PyCode_GetNumFree(co) > 0)
+ #define __Pyx_PyFrame_SetLineNumber(frame, lineno) (frame)->f_lineno = (lineno)
+#endif
+#if !CYTHON_FAST_THREAD_STATE || PY_VERSION_HEX < 0x02070000
+ #define __Pyx_PyThreadState_Current PyThreadState_GET()
+#elif PY_VERSION_HEX >= 0x03060000
+ #define __Pyx_PyThreadState_Current _PyThreadState_UncheckedGet()
+#elif PY_VERSION_HEX >= 0x03000000
+ #define __Pyx_PyThreadState_Current PyThreadState_GET()
+#else
+ #define __Pyx_PyThreadState_Current _PyThreadState_Current
+#endif
+/* Emulation of the PyThread_tss_* API for interpreters older than 3.7,
+   implemented with the PyThread_*_key functions. A key is a plain int, and the
+   value 0 (Py_tss_NEEDS_INIT) marks a key that has not been created. */
+#if PY_VERSION_HEX < 0x030700A2 && !defined(PyThread_tss_create) && !defined(Py_tss_NEEDS_INIT)
+#include "pythread.h"
+#define Py_tss_NEEDS_INIT 0
+typedef int Py_tss_t;
+/* Create a key and store it in *key; always reports success (0). */
+static CYTHON_INLINE int PyThread_tss_create(Py_tss_t *key) {
+    *key = PyThread_create_key();
+    return 0;
+}
+/* Allocate a key object in the "needs init" state; returns NULL when the
+   allocation fails instead of writing through a NULL pointer. */
+static CYTHON_INLINE Py_tss_t * PyThread_tss_alloc(void) {
+    Py_tss_t *key = (Py_tss_t *)PyObject_Malloc(sizeof(Py_tss_t));
+    if (key != NULL) {
+        *key = Py_tss_NEEDS_INIT;
+    }
+    return key;
+}
+/* Release a key object obtained from PyThread_tss_alloc(). */
+static CYTHON_INLINE void PyThread_tss_free(Py_tss_t *key) {
+    PyObject_Free(key);
+}
+/* Non-zero once *key differs from Py_tss_NEEDS_INIT. */
+static CYTHON_INLINE int PyThread_tss_is_created(Py_tss_t *key) {
+    return *key != Py_tss_NEEDS_INIT;
+}
+/* Delete the underlying key and reset *key to the "needs init" state. */
+static CYTHON_INLINE void PyThread_tss_delete(Py_tss_t *key) {
+    PyThread_delete_key(*key);
+    *key = Py_tss_NEEDS_INIT;
+}
+/* Associate value with *key; returns the result of PyThread_set_key_value(). */
+static CYTHON_INLINE int PyThread_tss_set(Py_tss_t *key, void *value) {
+    return PyThread_set_key_value(*key, value);
+}
+/* Return the value associated with *key via PyThread_get_key_value(). */
+static CYTHON_INLINE void * PyThread_tss_get(Py_tss_t *key) {
+    return PyThread_get_key_value(*key);
+}
+#endif
+#if CYTHON_COMPILING_IN_CPYTHON || defined(_PyDict_NewPresized)
+#define __Pyx_PyDict_NewPresized(n) ((n <= 8) ? PyDict_New() : _PyDict_NewPresized(n))
+#else
+#define __Pyx_PyDict_NewPresized(n) PyDict_New()
+#endif
+#if PY_MAJOR_VERSION >= 3 || CYTHON_FUTURE_DIVISION
+ #define __Pyx_PyNumber_Divide(x,y) PyNumber_TrueDivide(x,y)
+ #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceTrueDivide(x,y)
+#else
+ #define __Pyx_PyNumber_Divide(x,y) PyNumber_Divide(x,y)
+ #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceDivide(x,y)
+#endif
+#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030500A1 && CYTHON_USE_UNICODE_INTERNALS
+#define __Pyx_PyDict_GetItemStr(dict, name) _PyDict_GetItem_KnownHash(dict, name, ((PyASCIIObject *) name)->hash)
+#else
+#define __Pyx_PyDict_GetItemStr(dict, name) PyDict_GetItem(dict, name)
+#endif
+#if PY_VERSION_HEX > 0x03030000 && defined(PyUnicode_KIND)
+ #define CYTHON_PEP393_ENABLED 1
+ #if PY_VERSION_HEX >= 0x030C0000
+ #define __Pyx_PyUnicode_READY(op) (0)
+ #else
+ #define __Pyx_PyUnicode_READY(op) (likely(PyUnicode_IS_READY(op)) ?\
+ 0 : _PyUnicode_Ready((PyObject *)(op)))
+ #endif
+ #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_LENGTH(u)
+ #define __Pyx_PyUnicode_READ_CHAR(u, i) PyUnicode_READ_CHAR(u, i)
+ #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) PyUnicode_MAX_CHAR_VALUE(u)
+ #define __Pyx_PyUnicode_KIND(u) PyUnicode_KIND(u)
+ #define __Pyx_PyUnicode_DATA(u) PyUnicode_DATA(u)
+ #define __Pyx_PyUnicode_READ(k, d, i) PyUnicode_READ(k, d, i)
+ #define __Pyx_PyUnicode_WRITE(k, d, i, ch) PyUnicode_WRITE(k, d, i, ch)
+ #if PY_VERSION_HEX >= 0x030C0000
+ #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_LENGTH(u))
+ #else
+ #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x03090000
+ #define __Pyx_PyUnicode_IS_TRUE(u) (0 != (likely(PyUnicode_IS_READY(u)) ? PyUnicode_GET_LENGTH(u) : ((PyCompactUnicodeObject *)(u))->wstr_length))
+ #else
+ #define __Pyx_PyUnicode_IS_TRUE(u) (0 != (likely(PyUnicode_IS_READY(u)) ? PyUnicode_GET_LENGTH(u) : PyUnicode_GET_SIZE(u)))
+ #endif
+ #endif
+#else
+ #define CYTHON_PEP393_ENABLED 0
+ #define PyUnicode_1BYTE_KIND 1
+ #define PyUnicode_2BYTE_KIND 2
+ #define PyUnicode_4BYTE_KIND 4
+ #define __Pyx_PyUnicode_READY(op) (0)
+ #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_SIZE(u)
+ #define __Pyx_PyUnicode_READ_CHAR(u, i) ((Py_UCS4)(PyUnicode_AS_UNICODE(u)[i]))
+ #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) ((sizeof(Py_UNICODE) == 2) ? 65535 : 1114111)
+ #define __Pyx_PyUnicode_KIND(u) (sizeof(Py_UNICODE))
+ #define __Pyx_PyUnicode_DATA(u) ((void*)PyUnicode_AS_UNICODE(u))
+ #define __Pyx_PyUnicode_READ(k, d, i) ((void)(k), (Py_UCS4)(((Py_UNICODE*)d)[i]))
+ #define __Pyx_PyUnicode_WRITE(k, d, i, ch) (((void)(k)), ((Py_UNICODE*)d)[i] = ch)
+ #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_SIZE(u))
+#endif
+#if CYTHON_COMPILING_IN_PYPY
+ #define __Pyx_PyUnicode_Concat(a, b) PyNumber_Add(a, b)
+ #define __Pyx_PyUnicode_ConcatSafe(a, b) PyNumber_Add(a, b)
+#else
+ #define __Pyx_PyUnicode_Concat(a, b) PyUnicode_Concat(a, b)
+ #define __Pyx_PyUnicode_ConcatSafe(a, b) ((unlikely((a) == Py_None) || unlikely((b) == Py_None)) ?\
+ PyNumber_Add(a, b) : __Pyx_PyUnicode_Concat(a, b))
+#endif
+#if CYTHON_COMPILING_IN_PYPY && !defined(PyUnicode_Contains)
+ #define PyUnicode_Contains(u, s) PySequence_Contains(u, s)
+#endif
+#if CYTHON_COMPILING_IN_PYPY && !defined(PyByteArray_Check)
+ #define PyByteArray_Check(obj) PyObject_TypeCheck(obj, &PyByteArray_Type)
+#endif
+#if CYTHON_COMPILING_IN_PYPY && !defined(PyObject_Format)
+ #define PyObject_Format(obj, fmt) PyObject_CallMethod(obj, "__format__", "O", fmt)
+#endif
+#define __Pyx_PyString_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyString_Check(b) && !PyString_CheckExact(b)))) ? PyNumber_Remainder(a, b) : __Pyx_PyString_Format(a, b))
+#define __Pyx_PyUnicode_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyUnicode_Check(b) && !PyUnicode_CheckExact(b)))) ? PyNumber_Remainder(a, b) : PyUnicode_Format(a, b))
+#if PY_MAJOR_VERSION >= 3
+ #define __Pyx_PyString_Format(a, b) PyUnicode_Format(a, b)
+#else
+ #define __Pyx_PyString_Format(a, b) PyString_Format(a, b)
+#endif
+#if PY_MAJOR_VERSION < 3 && !defined(PyObject_ASCII)
+ #define PyObject_ASCII(o) PyObject_Repr(o)
+#endif
+#if PY_MAJOR_VERSION >= 3
+ #define PyBaseString_Type PyUnicode_Type
+ #define PyStringObject PyUnicodeObject
+ #define PyString_Type PyUnicode_Type
+ #define PyString_Check PyUnicode_Check
+ #define PyString_CheckExact PyUnicode_CheckExact
+#ifndef PyObject_Unicode
+ #define PyObject_Unicode PyObject_Str
+#endif
+#endif
+#if PY_MAJOR_VERSION >= 3
+ #define __Pyx_PyBaseString_Check(obj) PyUnicode_Check(obj)
+ #define __Pyx_PyBaseString_CheckExact(obj) PyUnicode_CheckExact(obj)
+#else
+ #define __Pyx_PyBaseString_Check(obj) (PyString_Check(obj) || PyUnicode_Check(obj))
+ #define __Pyx_PyBaseString_CheckExact(obj) (PyString_CheckExact(obj) || PyUnicode_CheckExact(obj))
+#endif
+#ifndef PySet_CheckExact
+ #define PySet_CheckExact(obj) (Py_TYPE(obj) == &PySet_Type)
+#endif
+#if PY_VERSION_HEX >= 0x030900A4
+ #define __Pyx_SET_REFCNT(obj, refcnt) Py_SET_REFCNT(obj, refcnt)
+ #define __Pyx_SET_SIZE(obj, size) Py_SET_SIZE(obj, size)
+#else
+ #define __Pyx_SET_REFCNT(obj, refcnt) Py_REFCNT(obj) = (refcnt)
+ #define __Pyx_SET_SIZE(obj, size) Py_SIZE(obj) = (size)
+#endif
+#if CYTHON_ASSUME_SAFE_MACROS
+ #define __Pyx_PySequence_SIZE(seq) Py_SIZE(seq)
+#else
+ #define __Pyx_PySequence_SIZE(seq) PySequence_Size(seq)
+#endif
+#if PY_MAJOR_VERSION >= 3
+ #define PyIntObject PyLongObject
+ #define PyInt_Type PyLong_Type
+ #define PyInt_Check(op) PyLong_Check(op)
+ #define PyInt_CheckExact(op) PyLong_CheckExact(op)
+ #define PyInt_FromString PyLong_FromString
+ #define PyInt_FromUnicode PyLong_FromUnicode
+ #define PyInt_FromLong PyLong_FromLong
+ #define PyInt_FromSize_t PyLong_FromSize_t
+ #define PyInt_FromSsize_t PyLong_FromSsize_t
+ #define PyInt_AsLong PyLong_AsLong
+ #define PyInt_AS_LONG PyLong_AS_LONG
+ #define PyInt_AsSsize_t PyLong_AsSsize_t
+ #define PyInt_AsUnsignedLongMask PyLong_AsUnsignedLongMask
+ #define PyInt_AsUnsignedLongLongMask PyLong_AsUnsignedLongLongMask
+ #define PyNumber_Int PyNumber_Long
+#endif
+#if PY_MAJOR_VERSION >= 3
+ #define PyBoolObject PyLongObject
+#endif
+#if PY_MAJOR_VERSION >= 3 && CYTHON_COMPILING_IN_PYPY
+ #ifndef PyUnicode_InternFromString
+ #define PyUnicode_InternFromString(s) PyUnicode_FromString(s)
+ #endif
+#endif
+#if PY_VERSION_HEX < 0x030200A4
+ typedef long Py_hash_t;
+ #define __Pyx_PyInt_FromHash_t PyInt_FromLong
+ #define __Pyx_PyInt_AsHash_t __Pyx_PyIndex_AsHash_t
+#else
+ #define __Pyx_PyInt_FromHash_t PyInt_FromSsize_t
+ #define __Pyx_PyInt_AsHash_t __Pyx_PyIndex_AsSsize_t
+#endif
+#if PY_MAJOR_VERSION >= 3
+ #define __Pyx_PyMethod_New(func, self, klass) ((self) ? ((void)(klass), PyMethod_New(func, self)) : __Pyx_NewRef(func))
+#else
+ #define __Pyx_PyMethod_New(func, self, klass) PyMethod_New(func, self, klass)
+#endif
+#if CYTHON_USE_ASYNC_SLOTS
+ #if PY_VERSION_HEX >= 0x030500B1
+ #define __Pyx_PyAsyncMethodsStruct PyAsyncMethods
+ #define __Pyx_PyType_AsAsync(obj) (Py_TYPE(obj)->tp_as_async)
+ #else
+ #define __Pyx_PyType_AsAsync(obj) ((__Pyx_PyAsyncMethodsStruct*) (Py_TYPE(obj)->tp_reserved))
+ #endif
+#else
+ #define __Pyx_PyType_AsAsync(obj) NULL
+#endif
+#ifndef __Pyx_PyAsyncMethodsStruct
+ typedef struct {
+ unaryfunc am_await;
+ unaryfunc am_aiter;
+ unaryfunc am_anext;
+ } __Pyx_PyAsyncMethodsStruct;
+#endif
+
+#if defined(_WIN32) || defined(WIN32) || defined(MS_WINDOWS)
+ #if !defined(_USE_MATH_DEFINES)
+ #define _USE_MATH_DEFINES
+ #endif
+#endif
+/* __PYX_NAN(): a float NaN. Uses the NAN macro from <math.h> when available,
+   otherwise builds the value by setting every byte of a float to 0xFF. */
+#include <math.h>
+#ifdef NAN
+#define __PYX_NAN() ((float) NAN)
+#else
+static CYTHON_INLINE float __PYX_NAN() {
+    float value;
+    memset(&value, 0xFF, sizeof(value));
+    return value;
+}
+#endif
+#if defined(__CYGWIN__) && defined(_LDBL_EQ_DBL)
+#define __Pyx_truncl trunc
+#else
+#define __Pyx_truncl truncl
+#endif
+
+/* Record where an error happened: the file name taken from __pyx_f[f_index],
+   the given line number, and the current line of this C file (__LINE__).
+   The (void) casts only reference the variables after assigning them. */
+#define __PYX_MARK_ERR_POS(f_index, lineno) \
+ { __pyx_filename = __pyx_f[f_index]; (void)__pyx_filename; __pyx_lineno = lineno; (void)__pyx_lineno; __pyx_clineno = __LINE__; (void)__pyx_clineno; }
+/* Record the error position as above, then jump to the label Ln_error. */
+#define __PYX_ERR(f_index, lineno, Ln_error) \
+ { __PYX_MARK_ERR_POS(f_index, lineno) goto Ln_error; }
+
+#ifndef __PYX_EXTERN_C
+ #ifdef __cplusplus
+ #define __PYX_EXTERN_C extern "C"
+ #else
+ #define __PYX_EXTERN_C extern
+ #endif
+#endif
+
+#define __PYX_HAVE__triangle_hash
+#define __PYX_HAVE_API__triangle_hash
+/* Early includes */
+/* NOTE(review): the bracketed header names below were missing and have been
+   restored to what Cython normally emits for a module that cimports numpy,
+   libcpp.vector and libc.math -- confirm against a freshly generated file. */
+#include <string.h>
+#include <stdio.h>
+#include "numpy/arrayobject.h"
+#include "numpy/ndarrayobject.h"
+#include "numpy/ndarraytypes.h"
+#include "numpy/arrayscalars.h"
+#include "numpy/ufuncobject.h"
+
+    /* NumPy API declarations from "numpy/__init__.pxd" */
+
+#include "ios"
+#include "new"
+#include "stdexcept"
+#include "typeinfo"
+#include <vector>
+#include <math.h>
+#include "pythread.h"
+#include <stdlib.h>
+#include "pystate.h"
+#ifdef _OPENMP
+#include <omp.h>
+#endif /* _OPENMP */
+
+#if defined(PYREX_WITHOUT_ASSERTIONS) && !defined(CYTHON_WITHOUT_ASSERTIONS)
+#define CYTHON_WITHOUT_ASSERTIONS
+#endif
+
+/* One row of the module's string table.
+   NOTE(review): field meanings are not established in this part of the file;
+   presumably p is the slot that receives the object created from the n-byte
+   C string s, and the remaining fields select encoding, str/unicode type and
+   interning -- confirm against the code that consumes the table. */
+typedef struct {PyObject **p; const char *s; const Py_ssize_t n; const char* encoding;
+ const char is_unicode; const char is_str; const char intern; } __Pyx_StringTabEntry;
+
+#define __PYX_DEFAULT_STRING_ENCODING_IS_ASCII 0
+#define __PYX_DEFAULT_STRING_ENCODING_IS_UTF8 0
+#define __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT (PY_MAJOR_VERSION >= 3 && __PYX_DEFAULT_STRING_ENCODING_IS_UTF8)
+#define __PYX_DEFAULT_STRING_ENCODING ""
+#define __Pyx_PyObject_FromString __Pyx_PyBytes_FromString
+#define __Pyx_PyObject_FromStringAndSize __Pyx_PyBytes_FromStringAndSize
+#define __Pyx_uchar_cast(c) ((unsigned char)c)
+#define __Pyx_long_cast(x) ((long)x)
+#define __Pyx_fits_Py_ssize_t(v, type, is_signed) (\
+ (sizeof(type) < sizeof(Py_ssize_t)) ||\
+ (sizeof(type) > sizeof(Py_ssize_t) &&\
+ likely(v < (type)PY_SSIZE_T_MAX ||\
+ v == (type)PY_SSIZE_T_MAX) &&\
+ (!is_signed || likely(v > (type)PY_SSIZE_T_MIN ||\
+ v == (type)PY_SSIZE_T_MIN))) ||\
+ (sizeof(type) == sizeof(Py_ssize_t) &&\
+ (is_signed || likely(v < (type)PY_SSIZE_T_MAX ||\
+ v == (type)PY_SSIZE_T_MAX))) )
+/* Non-zero when i is smaller than limit after both are converted to size_t;
+   a negative i converts to a very large unsigned value. */
+static CYTHON_INLINE int __Pyx_is_valid_index(Py_ssize_t i, Py_ssize_t limit) {
+    const size_t index = (size_t) i;
+    const size_t bound = (size_t) limit;
+    return index < bound;
+}
+/* __Pyx_sst_abs(value): absolute value, picking the first variant whose
+   condition holds. C++11 and later use std::abs, which requires <cstdlib>. */
+#if defined (__cplusplus) && __cplusplus >= 201103L
+    #include <cstdlib>
+    #define __Pyx_sst_abs(value) std::abs(value)
+#elif SIZEOF_INT >= SIZEOF_SIZE_T
+    #define __Pyx_sst_abs(value) abs(value)
+#elif SIZEOF_LONG >= SIZEOF_SIZE_T
+    #define __Pyx_sst_abs(value) labs(value)
+#elif defined (_MSC_VER)
+    #define __Pyx_sst_abs(value) ((Py_ssize_t)_abs64(value))
+#elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L
+    #define __Pyx_sst_abs(value) llabs(value)
+#elif defined (__GNUC__)
+    #define __Pyx_sst_abs(value) __builtin_llabs(value)
+#else
+    #define __Pyx_sst_abs(value) ((value<0) ? -value : value)
+#endif
+static CYTHON_INLINE const char* __Pyx_PyObject_AsString(PyObject*);
+static CYTHON_INLINE const char* __Pyx_PyObject_AsStringAndSize(PyObject*, Py_ssize_t* length);
+#define __Pyx_PyByteArray_FromString(s) PyByteArray_FromStringAndSize((const char*)s, strlen((const char*)s))
+#define __Pyx_PyByteArray_FromStringAndSize(s, l) PyByteArray_FromStringAndSize((const char*)s, l)
+#define __Pyx_PyBytes_FromString PyBytes_FromString
+#define __Pyx_PyBytes_FromStringAndSize PyBytes_FromStringAndSize
+static CYTHON_INLINE PyObject* __Pyx_PyUnicode_FromString(const char*);
+#if PY_MAJOR_VERSION < 3
+ #define __Pyx_PyStr_FromString __Pyx_PyBytes_FromString
+ #define __Pyx_PyStr_FromStringAndSize __Pyx_PyBytes_FromStringAndSize
+#else
+ #define __Pyx_PyStr_FromString __Pyx_PyUnicode_FromString
+ #define __Pyx_PyStr_FromStringAndSize __Pyx_PyUnicode_FromStringAndSize
+#endif
+#define __Pyx_PyBytes_AsWritableString(s) ((char*) PyBytes_AS_STRING(s))
+#define __Pyx_PyBytes_AsWritableSString(s) ((signed char*) PyBytes_AS_STRING(s))
+#define __Pyx_PyBytes_AsWritableUString(s) ((unsigned char*) PyBytes_AS_STRING(s))
+#define __Pyx_PyBytes_AsString(s) ((const char*) PyBytes_AS_STRING(s))
+#define __Pyx_PyBytes_AsSString(s) ((const signed char*) PyBytes_AS_STRING(s))
+#define __Pyx_PyBytes_AsUString(s) ((const unsigned char*) PyBytes_AS_STRING(s))
+#define __Pyx_PyObject_AsWritableString(s) ((char*) __Pyx_PyObject_AsString(s))
+#define __Pyx_PyObject_AsWritableSString(s) ((signed char*) __Pyx_PyObject_AsString(s))
+#define __Pyx_PyObject_AsWritableUString(s) ((unsigned char*) __Pyx_PyObject_AsString(s))
+#define __Pyx_PyObject_AsSString(s) ((const signed char*) __Pyx_PyObject_AsString(s))
+#define __Pyx_PyObject_AsUString(s) ((const unsigned char*) __Pyx_PyObject_AsString(s))
+#define __Pyx_PyObject_FromCString(s) __Pyx_PyObject_FromString((const char*)s)
+#define __Pyx_PyBytes_FromCString(s) __Pyx_PyBytes_FromString((const char*)s)
+#define __Pyx_PyByteArray_FromCString(s) __Pyx_PyByteArray_FromString((const char*)s)
+#define __Pyx_PyStr_FromCString(s) __Pyx_PyStr_FromString((const char*)s)
+#define __Pyx_PyUnicode_FromCString(s) __Pyx_PyUnicode_FromString((const char*)s)
+/* Count the Py_UNICODE units in u that precede the first zero unit. */
+static CYTHON_INLINE size_t __Pyx_Py_UNICODE_strlen(const Py_UNICODE *u) {
+    size_t length = 0;
+    while (u[length] != 0) {
+        ++length;
+    }
+    return length;
+}
+/* Builds a unicode object from a zero-terminated Py_UNICODE buffer; the
+ * length is computed with __Pyx_Py_UNICODE_strlen (so `u` is expanded twice). */
+#define __Pyx_PyUnicode_FromUnicode(u) PyUnicode_FromUnicode(u, __Pyx_Py_UNICODE_strlen(u))
+#define __Pyx_PyUnicode_FromUnicodeAndLength PyUnicode_FromUnicode
+#define __Pyx_PyUnicode_AsUnicode PyUnicode_AsUnicode
+/* Increments the reference count of obj and yields obj.  The argument is
+ * expanded twice, so it must not have side effects. */
+#define __Pyx_NewRef(obj) (Py_INCREF(obj), obj)
+/* Yields a new reference to Py_None; the argument b is ignored. */
+#define __Pyx_Owned_Py_None(b) __Pyx_NewRef(Py_None)
+static CYTHON_INLINE PyObject * __Pyx_PyBool_FromLong(long b);
+static CYTHON_INLINE int __Pyx_PyObject_IsTrue(PyObject*);
+static CYTHON_INLINE int __Pyx_PyObject_IsTrueAndDecref(PyObject*);
+static CYTHON_INLINE PyObject* __Pyx_PyNumber_IntOrLong(PyObject* x);
+/* Exact tuples are returned as a new reference to the same object; anything
+ * else goes through PySequence_Tuple. */
+#define __Pyx_PySequence_Tuple(obj)\
+ (likely(PyTuple_CheckExact(obj)) ? __Pyx_NewRef(obj) : PySequence_Tuple(obj))
+static CYTHON_INLINE Py_ssize_t __Pyx_PyIndex_AsSsize_t(PyObject*);
+static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t);
+static CYTHON_INLINE Py_hash_t __Pyx_PyIndex_AsHash_t(PyObject*);
+/* With CYTHON_ASSUME_SAFE_MACROS, exact float objects are read through the
+ * PyFloat_AS_DOUBLE macro; all other cases call PyFloat_AsDouble. */
+#if CYTHON_ASSUME_SAFE_MACROS
+#define __pyx_PyFloat_AsDouble(x) (PyFloat_CheckExact(x) ? PyFloat_AS_DOUBLE(x) : PyFloat_AsDouble(x))
+#else
+#define __pyx_PyFloat_AsDouble(x) PyFloat_AsDouble(x)
+#endif
+#define __pyx_PyFloat_AsFloat(x) ((float) __pyx_PyFloat_AsDouble(x))
+/* Integer coercion: an exact int/long yields a new reference to itself,
+ * otherwise PyNumber_Long (Python 3) or PyNumber_Int (Python 2) is called. */
+#if PY_MAJOR_VERSION >= 3
+#define __Pyx_PyNumber_Int(x) (PyLong_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Long(x))
+#else
+#define __Pyx_PyNumber_Int(x) (PyInt_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Int(x))
+#endif
+/* Float coercion with the same exact-type shortcut. */
+#define __Pyx_PyNumber_Float(x) (PyFloat_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Float(x))
+#if PY_MAJOR_VERSION < 3 && __PYX_DEFAULT_STRING_ENCODING_IS_ASCII
+/* 0 when sys.getdefaultencoding() is "ascii", 1 for any other encoding. */
+static int __Pyx_sys_getdefaultencoding_not_ascii;
+/* Query sys.getdefaultencoding() and record whether it is "ascii".
+ * For a non-ASCII default encoding, verify that encoding the 128 ASCII
+ * characters with it reproduces the same bytes; if not, set a ValueError.
+ * Returns 0 on success and -1 on failure. */
+static int __Pyx_init_sys_getdefaultencoding_params(void) {
+ PyObject* sys;
+ PyObject* default_encoding = NULL;
+ PyObject* ascii_chars_u = NULL;
+ PyObject* ascii_chars_b = NULL;
+ const char* default_encoding_c;
+ sys = PyImport_ImportModule("sys");
+ if (!sys) goto bad;
+ default_encoding = PyObject_CallMethod(sys, (char*) "getdefaultencoding", NULL);
+ /* The module reference is no longer needed once the call has been made. */
+ Py_DECREF(sys);
+ if (!default_encoding) goto bad;
+ default_encoding_c = PyBytes_AsString(default_encoding);
+ if (!default_encoding_c) goto bad;
+ if (strcmp(default_encoding_c, "ascii") == 0) {
+ __Pyx_sys_getdefaultencoding_not_ascii = 0;
+ } else {
+ /* Build the byte sequence 0..127 to round-trip through the default encoding. */
+ char ascii_chars[128];
+ int c;
+ for (c = 0; c < 128; c++) {
+ ascii_chars[c] = c;
+ }
+ __Pyx_sys_getdefaultencoding_not_ascii = 1;
+ ascii_chars_u = PyUnicode_DecodeASCII(ascii_chars, 128, NULL);
+ if (!ascii_chars_u) goto bad;
+ ascii_chars_b = PyUnicode_AsEncodedString(ascii_chars_u, default_encoding_c, NULL);
+ /* Reject the encoding unless the encoded result is a bytes object whose
+ * first 128 bytes equal the original ASCII table. */
+ if (!ascii_chars_b || !PyBytes_Check(ascii_chars_b) || memcmp(ascii_chars, PyBytes_AS_STRING(ascii_chars_b), 128) != 0) {
+ PyErr_Format(
+ PyExc_ValueError,
+ "This module compiled with c_string_encoding=ascii, but default encoding '%.200s' is not a superset of ascii.",
+ default_encoding_c);
+ goto bad;
+ }
+ Py_DECREF(ascii_chars_u);
+ Py_DECREF(ascii_chars_b);
+ }
+ Py_DECREF(default_encoding);
+ return 0;
+bad:
+ /* Release whatever was acquired before the failure; NULL entries are skipped. */
+ Py_XDECREF(default_encoding);
+ Py_XDECREF(ascii_chars_u);
+ Py_XDECREF(ascii_chars_b);
+ return -1;
+}
+#endif
+#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT && PY_MAJOR_VERSION >= 3
+#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_DecodeUTF8(c_str, size, NULL)
+#else
+#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_Decode(c_str, size, __PYX_DEFAULT_STRING_ENCODING, NULL)
+#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT
+static char* __PYX_DEFAULT_STRING_ENCODING;
+/* Store a malloc'ed copy of the name returned by sys.getdefaultencoding()
+ * in __PYX_DEFAULT_STRING_ENCODING.
+ * Returns 0 on success and -1 on failure (import error, call error,
+ * non-bytes result, or allocation failure). */
+static int __Pyx_init_sys_getdefaultencoding_params(void) {
+    PyObject *encoding_obj = NULL;
+    PyObject *sys_module = PyImport_ImportModule("sys");
+    char *encoding_name;
+    if (sys_module) {
+        encoding_obj = PyObject_CallMethod(sys_module, (char*) (const char*) "getdefaultencoding", NULL);
+        Py_DECREF(sys_module);
+    }
+    if (!encoding_obj) {
+        return -1;
+    }
+    encoding_name = PyBytes_AsString(encoding_obj);
+    if (encoding_name) {
+        /* Room for the name plus its terminating NUL. */
+        __PYX_DEFAULT_STRING_ENCODING = (char*) malloc(strlen(encoding_name) + 1);
+    }
+    if (!encoding_name || !__PYX_DEFAULT_STRING_ENCODING) {
+        Py_DECREF(encoding_obj);
+        return -1;
+    }
+    strcpy(__PYX_DEFAULT_STRING_ENCODING, encoding_name);
+    Py_DECREF(encoding_obj);
+    return 0;
+}
+#endif
+#endif
+
+
+/* Branch-prediction hints.  Test for GCC > 2.95: on such compilers the hints
+ * map to __builtin_expect, elsewhere they expand to the bare expression. */
+#if defined(__GNUC__) && (__GNUC__ > 2 || (__GNUC__ == 2 && (__GNUC_MINOR__ > 95)))
+ #define likely(x) __builtin_expect(!!(x), 1)
+ #define unlikely(x) __builtin_expect(!!(x), 0)
+#else /* !__GNUC__ or GCC < 2.95 */
+ #define likely(x) (x)
+ #define unlikely(x) (x)
+#endif /* __GNUC__ */
+/* No-op that only references its argument (cast to void). */
+static CYTHON_INLINE void __Pyx_pretend_to_initialize(void* ptr) { (void)ptr; }
+
+/* File-scope state used by the generated module code.
+ * NOTE(review): by Cython naming convention __pyx_m is presumably the module
+ * object, __pyx_d its dict and __pyx_b the builtins module; the assignments
+ * are outside this excerpt -- confirm. */
+static PyObject *__pyx_m = NULL;
+static PyObject *__pyx_d;
+static PyObject *__pyx_b;
+static PyObject *__pyx_cython_runtime = NULL;
+static PyObject *__pyx_empty_tuple;
+static PyObject *__pyx_empty_bytes;
+static PyObject *__pyx_empty_unicode;
+/* Current-position bookkeeping; __pyx_cfilenm starts out as this C file's name. */
+static int __pyx_lineno;
+static int __pyx_clineno = 0;
+static const char * __pyx_cfilenm= __FILE__;
+static const char *__pyx_filename;
+
+/* Header.proto */
+#if !defined(CYTHON_CCOMPLEX)
+ #if defined(__cplusplus)
+ #define CYTHON_CCOMPLEX 1
+ #elif defined(_Complex_I)
+ #define CYTHON_CCOMPLEX 1
+ #else
+ #define CYTHON_CCOMPLEX 0
+ #endif
+#endif
+/* Pull in the complex-number support matching the compilation mode:
+ * <complex> provides ::std::complex for C++, <complex.h> provides the C99
+ * _Complex types and _Complex_I for C. */
+#if CYTHON_CCOMPLEX
+ #ifdef __cplusplus
+ #include <complex>
+ #else
+ #include <complex.h>
+ #endif
+#endif
+#if CYTHON_CCOMPLEX && !defined(__cplusplus) && defined(__sun__) && defined(__GNUC__)
+ #undef _Complex_I
+ #define _Complex_I 1.0fj
+#endif
+
+
+/* Table of source file names referenced by the generated code.
+ * NOTE(review): presumably indexed when setting __pyx_filename for
+ * tracebacks; the uses are outside this excerpt -- confirm. */
+static const char *__pyx_f[] = {
+ "triangle_hash.pyx",
+ "stringsource",
+ "__init__.pxd",
+ "type.pxd",
+};
+/* MemviewSliceStruct.proto */
+struct __pyx_memoryview_obj;
+/* Plain-C description of a memoryview slice: a pointer to the associated
+ * memoryview object, a data pointer, and per-dimension shape, strides and
+ * suboffsets (room for 8 dimensions each). */
+typedef struct {
+ struct __pyx_memoryview_obj *memview;
+ char *data;
+ Py_ssize_t shape[8];
+ Py_ssize_t strides[8];
+ Py_ssize_t suboffsets[8];
+} __Pyx_memviewslice;
+/* Extent of the first dimension of slice m. */
+#define __Pyx_MemoryView_Len(m) (m.shape[0])
+
+/* Atomics.proto */
+/* pythread.h declares PyThread_type_lock, used by the memoryview lock below. */
+#include <pythread.h>
+#ifndef CYTHON_ATOMICS
+ #define CYTHON_ATOMICS 1
+#endif
+#define __PYX_CYTHON_ATOMICS_ENABLED() CYTHON_ATOMICS
+#define __pyx_atomic_int_type int
+/* Select an atomic increment/decrement implementation:
+ * - GCC >= 4.1.2: the __sync_fetch_and_add / __sync_fetch_and_sub builtins;
+ * - MSVC (only with CYTHON_COMPILING_IN_NOGIL): _InterlockedExchangeAdd on a long;
+ * - otherwise atomics are disabled (CYTHON_ATOMICS forced to 0). */
+#if CYTHON_ATOMICS && (__GNUC__ >= 5 || (__GNUC__ == 4 &&\
+ (__GNUC_MINOR__ > 1 ||\
+ (__GNUC_MINOR__ == 1 && __GNUC_PATCHLEVEL__ >= 2))))
+ #define __pyx_atomic_incr_aligned(value) __sync_fetch_and_add(value, 1)
+ #define __pyx_atomic_decr_aligned(value) __sync_fetch_and_sub(value, 1)
+ #ifdef __PYX_DEBUG_ATOMICS
+ #warning "Using GNU atomics"
+ #endif
+#elif CYTHON_ATOMICS && defined(_MSC_VER) && CYTHON_COMPILING_IN_NOGIL
+ /* intrin.h declares the _InterlockedExchangeAdd intrinsic. */
+ #include <intrin.h>
+ #undef __pyx_atomic_int_type
+ #define __pyx_atomic_int_type long
+ #pragma intrinsic (_InterlockedExchangeAdd)
+ #define __pyx_atomic_incr_aligned(value) _InterlockedExchangeAdd(value, 1)
+ #define __pyx_atomic_decr_aligned(value) _InterlockedExchangeAdd(value, -1)
+ #ifdef __PYX_DEBUG_ATOMICS
+ #pragma message ("Using MSVC atomics")
+ #endif
+#else
+ #undef CYTHON_ATOMICS
+ #define CYTHON_ATOMICS 0
+ #ifdef __PYX_DEBUG_ATOMICS
+ #warning "Not using atomics"
+ #endif
+#endif
+typedef volatile __pyx_atomic_int_type __pyx_atomic_int;
+/* Acquisition-count updates: atomic when available, otherwise routed through
+ * the *_locked helpers together with the memoryview's lock. */
+#if CYTHON_ATOMICS
+ #define __pyx_add_acquisition_count(memview)\
+ __pyx_atomic_incr_aligned(__pyx_get_slice_count_pointer(memview))
+ #define __pyx_sub_acquisition_count(memview)\
+ __pyx_atomic_decr_aligned(__pyx_get_slice_count_pointer(memview))
+#else
+ #define __pyx_add_acquisition_count(memview)\
+ __pyx_add_acquisition_count_locked(__pyx_get_slice_count_pointer(memview), memview->lock)
+ #define __pyx_sub_acquisition_count(memview)\
+ __pyx_sub_acquisition_count_locked(__pyx_get_slice_count_pointer(memview), memview->lock)
+#endif
+
+/* ForceInitThreads.proto */
+/* Defaults to 0 unless the build defines it beforehand. */
+#ifndef __PYX_FORCE_INIT_THREADS
+ #define __PYX_FORCE_INIT_THREADS 0
+#endif
+
+/* NoFastGil.proto */
+/* GIL helpers map straight to the CPython API; the FastGIL hooks expand to nothing. */
+#define __Pyx_PyGILState_Ensure PyGILState_Ensure
+#define __Pyx_PyGILState_Release PyGILState_Release
+#define __Pyx_FastGIL_Remember()
+#define __Pyx_FastGIL_Forget()
+#define __Pyx_FastGilFuncInit()
+
+/* BufferFormatStructs.proto */
+/* True when converting -1 to `type` yields a positive value, i.e. the type is unsigned. */
+#define IS_UNSIGNED(type) (((type) -1) > 0)
+struct __Pyx_StructField_;
+#define __PYX_BUF_FLAGS_PACKED_STRUCT (1 << 0)
+/* Description of one element type: name, optional field list, size, array
+ * extents (up to 8), dimension count, type-group code, signedness and flags.
+ * NOTE(review): `flags` presumably carries __PYX_BUF_FLAGS_* bits -- confirm. */
+typedef struct {
+ const char* name;
+ struct __Pyx_StructField_* fields;
+ size_t size;
+ size_t arraysize[8];
+ int ndim;
+ char typegroup;
+ char is_unsigned;
+ int flags;
+} __Pyx_TypeInfo;
+/* One named field: its type description and its offset. */
+typedef struct __Pyx_StructField_ {
+ __Pyx_TypeInfo* type;
+ const char* name;
+ size_t offset;
+} __Pyx_StructField;
+/* Stack entry pairing a field with the offset of its parent. */
+typedef struct {
+ __Pyx_StructField* field;
+ size_t parent_offset;
+} __Pyx_BufFmt_StackElem;
+/* Working state for buffer format handling; the code that reads and updates
+ * these fields is outside this excerpt. */
+typedef struct {
+ __Pyx_StructField root;
+ __Pyx_BufFmt_StackElem* head;
+ size_t fmt_offset;
+ size_t new_count, enc_count;
+ size_t struct_alignment;
+ int is_complex;
+ char enc_type;
+ char new_packmode;
+ char enc_packmode;
+ char is_valid_array;
+} __Pyx_BufFmt_Context;
+
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":689
+ * # in Cython to enable them only on the right systems.
+ *
+ * ctypedef npy_int8 int8_t # <<<<<<<<<<<<<<
+ * ctypedef npy_int16 int16_t
+ * ctypedef npy_int32 int32_t
+ */
+typedef npy_int8 __pyx_t_5numpy_int8_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":690
+ *
+ * ctypedef npy_int8 int8_t
+ * ctypedef npy_int16 int16_t # <<<<<<<<<<<<<<
+ * ctypedef npy_int32 int32_t
+ * ctypedef npy_int64 int64_t
+ */
+typedef npy_int16 __pyx_t_5numpy_int16_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":691
+ * ctypedef npy_int8 int8_t
+ * ctypedef npy_int16 int16_t
+ * ctypedef npy_int32 int32_t # <<<<<<<<<<<<<<
+ * ctypedef npy_int64 int64_t
+ * #ctypedef npy_int96 int96_t
+ */
+typedef npy_int32 __pyx_t_5numpy_int32_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":692
+ * ctypedef npy_int16 int16_t
+ * ctypedef npy_int32 int32_t
+ * ctypedef npy_int64 int64_t # <<<<<<<<<<<<<<
+ * #ctypedef npy_int96 int96_t
+ * #ctypedef npy_int128 int128_t
+ */
+typedef npy_int64 __pyx_t_5numpy_int64_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":696
+ * #ctypedef npy_int128 int128_t
+ *
+ * ctypedef npy_uint8 uint8_t # <<<<<<<<<<<<<<
+ * ctypedef npy_uint16 uint16_t
+ * ctypedef npy_uint32 uint32_t
+ */
+typedef npy_uint8 __pyx_t_5numpy_uint8_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":697
+ *
+ * ctypedef npy_uint8 uint8_t
+ * ctypedef npy_uint16 uint16_t # <<<<<<<<<<<<<<
+ * ctypedef npy_uint32 uint32_t
+ * ctypedef npy_uint64 uint64_t
+ */
+typedef npy_uint16 __pyx_t_5numpy_uint16_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":698
+ * ctypedef npy_uint8 uint8_t
+ * ctypedef npy_uint16 uint16_t
+ * ctypedef npy_uint32 uint32_t # <<<<<<<<<<<<<<
+ * ctypedef npy_uint64 uint64_t
+ * #ctypedef npy_uint96 uint96_t
+ */
+typedef npy_uint32 __pyx_t_5numpy_uint32_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":699
+ * ctypedef npy_uint16 uint16_t
+ * ctypedef npy_uint32 uint32_t
+ * ctypedef npy_uint64 uint64_t # <<<<<<<<<<<<<<
+ * #ctypedef npy_uint96 uint96_t
+ * #ctypedef npy_uint128 uint128_t
+ */
+typedef npy_uint64 __pyx_t_5numpy_uint64_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":703
+ * #ctypedef npy_uint128 uint128_t
+ *
+ * ctypedef npy_float32 float32_t # <<<<<<<<<<<<<<
+ * ctypedef npy_float64 float64_t
+ * #ctypedef npy_float80 float80_t
+ */
+typedef npy_float32 __pyx_t_5numpy_float32_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":704
+ *
+ * ctypedef npy_float32 float32_t
+ * ctypedef npy_float64 float64_t # <<<<<<<<<<<<<<
+ * #ctypedef npy_float80 float80_t
+ * #ctypedef npy_float128 float128_t
+ */
+typedef npy_float64 __pyx_t_5numpy_float64_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":713
+ * # The int types are mapped a bit surprising --
+ * # numpy.int corresponds to 'l' and numpy.long to 'q'
+ * ctypedef npy_long int_t # <<<<<<<<<<<<<<
+ * ctypedef npy_longlong long_t
+ * ctypedef npy_longlong longlong_t
+ */
+typedef npy_long __pyx_t_5numpy_int_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":714
+ * # numpy.int corresponds to 'l' and numpy.long to 'q'
+ * ctypedef npy_long int_t
+ * ctypedef npy_longlong long_t # <<<<<<<<<<<<<<
+ * ctypedef npy_longlong longlong_t
+ *
+ */
+typedef npy_longlong __pyx_t_5numpy_long_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":715
+ * ctypedef npy_long int_t
+ * ctypedef npy_longlong long_t
+ * ctypedef npy_longlong longlong_t # <<<<<<<<<<<<<<
+ *
+ * ctypedef npy_ulong uint_t
+ */
+typedef npy_longlong __pyx_t_5numpy_longlong_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":717
+ * ctypedef npy_longlong longlong_t
+ *
+ * ctypedef npy_ulong uint_t # <<<<<<<<<<<<<<
+ * ctypedef npy_ulonglong ulong_t
+ * ctypedef npy_ulonglong ulonglong_t
+ */
+typedef npy_ulong __pyx_t_5numpy_uint_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":718
+ *
+ * ctypedef npy_ulong uint_t
+ * ctypedef npy_ulonglong ulong_t # <<<<<<<<<<<<<<
+ * ctypedef npy_ulonglong ulonglong_t
+ *
+ */
+typedef npy_ulonglong __pyx_t_5numpy_ulong_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":719
+ * ctypedef npy_ulong uint_t
+ * ctypedef npy_ulonglong ulong_t
+ * ctypedef npy_ulonglong ulonglong_t # <<<<<<<<<<<<<<
+ *
+ * ctypedef npy_intp intp_t
+ */
+typedef npy_ulonglong __pyx_t_5numpy_ulonglong_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":721
+ * ctypedef npy_ulonglong ulonglong_t
+ *
+ * ctypedef npy_intp intp_t # <<<<<<<<<<<<<<
+ * ctypedef npy_uintp uintp_t
+ *
+ */
+typedef npy_intp __pyx_t_5numpy_intp_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":722
+ *
+ * ctypedef npy_intp intp_t
+ * ctypedef npy_uintp uintp_t # <<<<<<<<<<<<<<
+ *
+ * ctypedef npy_double float_t
+ */
+typedef npy_uintp __pyx_t_5numpy_uintp_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":724
+ * ctypedef npy_uintp uintp_t
+ *
+ * ctypedef npy_double float_t # <<<<<<<<<<<<<<
+ * ctypedef npy_double double_t
+ * ctypedef npy_longdouble longdouble_t
+ */
+typedef npy_double __pyx_t_5numpy_float_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":725
+ *
+ * ctypedef npy_double float_t
+ * ctypedef npy_double double_t # <<<<<<<<<<<<<<
+ * ctypedef npy_longdouble longdouble_t
+ *
+ */
+typedef npy_double __pyx_t_5numpy_double_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":726
+ * ctypedef npy_double float_t
+ * ctypedef npy_double double_t
+ * ctypedef npy_longdouble longdouble_t # <<<<<<<<<<<<<<
+ *
+ * ctypedef npy_cfloat cfloat_t
+ */
+typedef npy_longdouble __pyx_t_5numpy_longdouble_t;
+/* Declarations.proto */
+/* Single-precision complex type: ::std::complex<float> in C++, float _Complex
+ * in C, or a plain {real, imag} struct when CYTHON_CCOMPLEX is off. */
+#if CYTHON_CCOMPLEX
+ #ifdef __cplusplus
+ typedef ::std::complex< float > __pyx_t_float_complex;
+ #else
+ typedef float _Complex __pyx_t_float_complex;
+ #endif
+#else
+ typedef struct { float real, imag; } __pyx_t_float_complex;
+#endif
+static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float, float);
+
+/* Declarations.proto */
+/* Double-precision complex type: ::std::complex<double> in C++, double _Complex
+ * in C, or a plain {real, imag} struct when CYTHON_CCOMPLEX is off. */
+#if CYTHON_CCOMPLEX
+ #ifdef __cplusplus
+ typedef ::std::complex< double > __pyx_t_double_complex;
+ #else
+ typedef double _Complex __pyx_t_double_complex;
+ #endif
+#else
+ typedef struct { double real, imag; } __pyx_t_double_complex;
+#endif
+static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double, double);
+
+
+/*--- Type declarations ---*/
+struct __pyx_obj_13triangle_hash_TriangleHash;
+struct __pyx_array_obj;
+struct __pyx_MemviewEnum_obj;
+struct __pyx_memoryview_obj;
+struct __pyx_memoryviewslice_obj;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":728
+ * ctypedef npy_longdouble longdouble_t
+ *
+ * ctypedef npy_cfloat cfloat_t # <<<<<<<<<<<<<<
+ * ctypedef npy_cdouble cdouble_t
+ * ctypedef npy_clongdouble clongdouble_t
+ */
+typedef npy_cfloat __pyx_t_5numpy_cfloat_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":729
+ *
+ * ctypedef npy_cfloat cfloat_t
+ * ctypedef npy_cdouble cdouble_t # <<<<<<<<<<<<<<
+ * ctypedef npy_clongdouble clongdouble_t
+ *
+ */
+typedef npy_cdouble __pyx_t_5numpy_cdouble_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":730
+ * ctypedef npy_cfloat cfloat_t
+ * ctypedef npy_cdouble cdouble_t
+ * ctypedef npy_clongdouble clongdouble_t # <<<<<<<<<<<<<<
+ *
+ * ctypedef npy_cdouble complex_t
+ */
+typedef npy_clongdouble __pyx_t_5numpy_clongdouble_t;
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":732
+ * ctypedef npy_clongdouble clongdouble_t
+ *
+ * ctypedef npy_cdouble complex_t # <<<<<<<<<<<<<<
+ *
+ * cdef inline object PyArray_MultiIterNew1(a):
+ */
+typedef npy_cdouble __pyx_t_5numpy_complex_t;
+
+/* "triangle_hash.pyx":9
+ * from libc.math cimport floor, ceil
+ *
+ * cdef class TriangleHash: # <<<<<<<<<<<<<<
+ * cdef vector[vector[int]] spatial_hash
+ * cdef int resolution
+ */
+/* Object layout of the TriangleHash extension type declared in
+ * triangle_hash.pyx. */
+struct __pyx_obj_13triangle_hash_TriangleHash {
+ PyObject_HEAD
+ struct __pyx_vtabstruct_13triangle_hash_TriangleHash *__pyx_vtab;
+ /* Declared in the .pyx source as `vector[vector[int]] spatial_hash`. */
+ std::vector<std::vector<int> > spatial_hash;
+ int resolution;
+};
+
+
+/* "View.MemoryView":106
+ *
+ * @cname("__pyx_array")
+ * cdef class array: # <<<<<<<<<<<<<<
+ *
+ * cdef:
+ */
+/* Object layout of the View.MemoryView `array` extension type. */
+struct __pyx_array_obj {
+ PyObject_HEAD
+ struct __pyx_vtabstruct_array *__pyx_vtab;
+ char *data;
+ Py_ssize_t len;
+ char *format;
+ int ndim;
+ Py_ssize_t *_shape;
+ Py_ssize_t *_strides;
+ Py_ssize_t itemsize;
+ PyObject *mode;
+ PyObject *_format;
+ /* Function pointer taking a void*; NOTE(review): presumably invoked on
+ * `data` at deallocation -- confirm against the array implementation. */
+ void (*callback_free_data)(void *);
+ int free_data;
+ int dtype_is_object;
+};
+
+
+/* "View.MemoryView":280
+ *
+ * @cname('__pyx_MemviewEnum')
+ * cdef class Enum(object): # <<<<<<<<<<<<<<
+ * cdef object name
+ * def __init__(self, name):
+ */
+/* Object layout of the View.MemoryView `Enum` type: a single `name` object. */
+struct __pyx_MemviewEnum_obj {
+ PyObject_HEAD
+ PyObject *name;
+};
+
+
+/* "View.MemoryView":331
+ *
+ * @cname('__pyx_memoryview')
+ * cdef class memoryview(object): # <<<<<<<<<<<<<<
+ *
+ * cdef object obj
+ */
+/* Object layout of the View.MemoryView `memoryview` extension type. */
+struct __pyx_memoryview_obj {
+ PyObject_HEAD
+ struct __pyx_vtabstruct_memoryview *__pyx_vtab;
+ PyObject *obj;
+ PyObject *_size;
+ PyObject *_array_interface;
+ /* Lock passed to the *_acquisition_count_locked helpers when atomics are off. */
+ PyThread_type_lock lock;
+ /* Counter storage plus the pointer that __pyx_get_slice_count_pointer
+ * reads (acquisition_count_aligned_p). */
+ __pyx_atomic_int acquisition_count[2];
+ __pyx_atomic_int *acquisition_count_aligned_p;
+ Py_buffer view;
+ int flags;
+ int dtype_is_object;
+ __Pyx_TypeInfo *typeinfo;
+};
+
+
+/* "View.MemoryView":967
+ *
+ * @cname('__pyx_memoryviewslice')
+ * cdef class _memoryviewslice(memoryview): # <<<<<<<<<<<<<<
+ * "Internal class for passing memoryview slices to Python"
+ *
+ */
+/* Object layout of `_memoryviewslice`: embeds the memoryview layout as its
+ * first member and adds the originating slice, an object reference and two
+ * element conversion function pointers. */
+struct __pyx_memoryviewslice_obj {
+ struct __pyx_memoryview_obj __pyx_base;
+ __Pyx_memviewslice from_slice;
+ PyObject *from_object;
+ PyObject *(*to_object_func)(char *);
+ int (*to_dtype_func)(char *, PyObject *);
+};
+
+
+
+/* "triangle_hash.pyx":9
+ * from libc.math cimport floor, ceil
+ *
+ * cdef class TriangleHash: # <<<<<<<<<<<<<<
+ * cdef vector[vector[int]] spatial_hash
+ * cdef int resolution
+ */
+
+/* Method table for TriangleHash: `_build_hash` and `query` both take the
+ * instance and a memoryview slice; `query` also takes a __pyx_skip_dispatch flag. */
+struct __pyx_vtabstruct_13triangle_hash_TriangleHash {
+ int (*_build_hash)(struct __pyx_obj_13triangle_hash_TriangleHash *, __Pyx_memviewslice);
+ PyObject *(*query)(struct __pyx_obj_13triangle_hash_TriangleHash *, __Pyx_memviewslice, int __pyx_skip_dispatch);
+};
+static struct __pyx_vtabstruct_13triangle_hash_TriangleHash *__pyx_vtabptr_13triangle_hash_TriangleHash;
+
+
+/* "View.MemoryView":106
+ *
+ * @cname("__pyx_array")
+ * cdef class array: # <<<<<<<<<<<<<<
+ *
+ * cdef:
+ */
+
+/* Method table for the `array` type: a single `get_memview` entry. */
+struct __pyx_vtabstruct_array {
+ PyObject *(*get_memview)(struct __pyx_array_obj *);
+};
+static struct __pyx_vtabstruct_array *__pyx_vtabptr_array;
+
+
+/* "View.MemoryView":331
+ *
+ * @cname('__pyx_memoryview')
+ * cdef class memoryview(object): # <<<<<<<<<<<<<<
+ *
+ * cdef object obj
+ */
+
+/* Method table for the `memoryview` type; every entry receives the
+ * memoryview instance as its first argument. */
+struct __pyx_vtabstruct_memoryview {
+ char *(*get_item_pointer)(struct __pyx_memoryview_obj *, PyObject *);
+ PyObject *(*is_slice)(struct __pyx_memoryview_obj *, PyObject *);
+ PyObject *(*setitem_slice_assignment)(struct __pyx_memoryview_obj *, PyObject *, PyObject *);
+ PyObject *(*setitem_slice_assign_scalar)(struct __pyx_memoryview_obj *, struct __pyx_memoryview_obj *, PyObject *);
+ PyObject *(*setitem_indexed)(struct __pyx_memoryview_obj *, PyObject *, PyObject *);
+ PyObject *(*convert_item_to_object)(struct __pyx_memoryview_obj *, char *);
+ PyObject *(*assign_item_from_object)(struct __pyx_memoryview_obj *, char *, PyObject *);
+};
+static struct __pyx_vtabstruct_memoryview *__pyx_vtabptr_memoryview;
+
+
+/* "View.MemoryView":967
+ *
+ * @cname('__pyx_memoryviewslice')
+ * cdef class _memoryviewslice(memoryview): # <<<<<<<<<<<<<<
+ * "Internal class for passing memoryview slices to Python"
+ *
+ */
+
+/* Method table for `_memoryviewslice`: embeds the memoryview table and adds
+ * no entries of its own. */
+struct __pyx_vtabstruct__memoryviewslice {
+ struct __pyx_vtabstruct_memoryview __pyx_base;
+};
+static struct __pyx_vtabstruct__memoryviewslice *__pyx_vtabptr__memoryviewslice;
+
+/* --- Runtime support code (head) --- */
+/* Refnanny.proto */
+/* Reference-count debugging hooks.  With CYTHON_REFNANNY enabled, the
+ * __Pyx_INCREF/DECREF/GOTREF/GIVEREF macros are routed through the
+ * __Pyx_RefNanny function table together with the current __LINE__;
+ * otherwise they collapse to the plain Py_* macros or to nothing. */
+#ifndef CYTHON_REFNANNY
+ #define CYTHON_REFNANNY 0
+#endif
+#if CYTHON_REFNANNY
+ typedef struct {
+ void (*INCREF)(void*, PyObject*, int);
+ void (*DECREF)(void*, PyObject*, int);
+ void (*GOTREF)(void*, PyObject*, int);
+ void (*GIVEREF)(void*, PyObject*, int);
+ void* (*SetupContext)(const char*, int, const char*);
+ void (*FinishContext)(void**);
+ } __Pyx_RefNannyAPIStruct;
+ static __Pyx_RefNannyAPIStruct *__Pyx_RefNanny = NULL;
+ static __Pyx_RefNannyAPIStruct *__Pyx_RefNannyImportAPI(const char *modname);
+ #define __Pyx_RefNannyDeclarations void *__pyx_refnanny = NULL;
+/* With threads, the context is set up while holding the GIL when the caller
+ * asks for it via acquire_gil. */
+#ifdef WITH_THREAD
+ #define __Pyx_RefNannySetupContext(name, acquire_gil)\
+ if (acquire_gil) {\
+ PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\
+ __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), __LINE__, __FILE__);\
+ PyGILState_Release(__pyx_gilstate_save);\
+ } else {\
+ __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), __LINE__, __FILE__);\
+ }
+#else
+ #define __Pyx_RefNannySetupContext(name, acquire_gil)\
+ __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), __LINE__, __FILE__)
+#endif
+ #define __Pyx_RefNannyFinishContext()\
+ __Pyx_RefNanny->FinishContext(&__pyx_refnanny)
+ #define __Pyx_INCREF(r) __Pyx_RefNanny->INCREF(__pyx_refnanny, (PyObject *)(r), __LINE__)
+ #define __Pyx_DECREF(r) __Pyx_RefNanny->DECREF(__pyx_refnanny, (PyObject *)(r), __LINE__)
+ #define __Pyx_GOTREF(r) __Pyx_RefNanny->GOTREF(__pyx_refnanny, (PyObject *)(r), __LINE__)
+ #define __Pyx_GIVEREF(r) __Pyx_RefNanny->GIVEREF(__pyx_refnanny, (PyObject *)(r), __LINE__)
+ /* X-variants skip NULL pointers. */
+ #define __Pyx_XINCREF(r) do { if((r) != NULL) {__Pyx_INCREF(r); }} while(0)
+ #define __Pyx_XDECREF(r) do { if((r) != NULL) {__Pyx_DECREF(r); }} while(0)
+ #define __Pyx_XGOTREF(r) do { if((r) != NULL) {__Pyx_GOTREF(r); }} while(0)
+ #define __Pyx_XGIVEREF(r) do { if((r) != NULL) {__Pyx_GIVEREF(r);}} while(0)
+#else
+ #define __Pyx_RefNannyDeclarations
+ #define __Pyx_RefNannySetupContext(name, acquire_gil)
+ #define __Pyx_RefNannyFinishContext()
+ #define __Pyx_INCREF(r) Py_INCREF(r)
+ #define __Pyx_DECREF(r) Py_DECREF(r)
+ #define __Pyx_GOTREF(r)
+ #define __Pyx_GIVEREF(r)
+ #define __Pyx_XINCREF(r) Py_XINCREF(r)
+ #define __Pyx_XDECREF(r) Py_XDECREF(r)
+ #define __Pyx_XGOTREF(r)
+ #define __Pyx_XGIVEREF(r)
+#endif
+/* Replace r with v, releasing the previous value only after the new one is
+ * stored.  The XDECREF form tolerates a NULL previous value. */
+#define __Pyx_XDECREF_SET(r, v) do {\
+ PyObject *tmp = (PyObject *) r;\
+ r = v; __Pyx_XDECREF(tmp);\
+ } while (0)
+#define __Pyx_DECREF_SET(r, v) do {\
+ PyObject *tmp = (PyObject *) r;\
+ r = v; __Pyx_DECREF(tmp);\
+ } while (0)
+/* Set r to NULL first, then release the old reference; XCLEAR does nothing
+ * when r is already NULL. */
+#define __Pyx_CLEAR(r) do { PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);} while(0)
+#define __Pyx_XCLEAR(r) do { if((r) != NULL) {PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);}} while(0)
+
+/* PyObjectGetAttrStr.proto */
+#if CYTHON_USE_TYPE_SLOTS
+static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStr(PyObject* obj, PyObject* attr_name);
+#else
+#define __Pyx_PyObject_GetAttrStr(o,n) PyObject_GetAttr(o,n)
+#endif
+
+/* GetBuiltinName.proto */
+static PyObject *__Pyx_GetBuiltinName(PyObject *name);
+
+/* RaiseArgTupleInvalid.proto */
+static void __Pyx_RaiseArgtupleInvalid(const char* func_name, int exact,
+ Py_ssize_t num_min, Py_ssize_t num_max, Py_ssize_t num_found);
+
+/* RaiseDoubleKeywords.proto */
+static void __Pyx_RaiseDoubleKeywordsError(const char* func_name, PyObject* kw_name);
+
+/* ParseKeywords.proto */
+static int __Pyx_ParseOptionalKeywords(PyObject *kwds, PyObject **argnames[],\
+ PyObject *kwds2, PyObject *values[], Py_ssize_t num_pos_args,\
+ const char* function_name);
+
+/* MemviewSliceInit.proto */
+/* Maximum number of buffer dimensions.  The generated text carried an
+ * unexpanded template placeholder here; 8 matches the shape/strides/suboffsets
+ * array length used by __Pyx_memviewslice. */
+#define __Pyx_BUF_MAX_NDIMS 8
+#define __Pyx_MEMVIEW_DIRECT 1
+#define __Pyx_MEMVIEW_PTR 2
+#define __Pyx_MEMVIEW_FULL 4
+#define __Pyx_MEMVIEW_CONTIG 8
+#define __Pyx_MEMVIEW_STRIDED 16
+#define __Pyx_MEMVIEW_FOLLOW 32
+#define __Pyx_IS_C_CONTIG 1
+#define __Pyx_IS_F_CONTIG 2
+static int __Pyx_init_memviewslice(
+ struct __pyx_memoryview_obj *memview,
+ int ndim,
+ __Pyx_memviewslice *memviewslice,
+ int memview_is_new_reference);
+static CYTHON_INLINE int __pyx_add_acquisition_count_locked(
+ __pyx_atomic_int *acquisition_count, PyThread_type_lock lock);
+static CYTHON_INLINE int __pyx_sub_acquisition_count_locked(
+ __pyx_atomic_int *acquisition_count, PyThread_type_lock lock);
+#define __pyx_get_slice_count_pointer(memview) (memview->acquisition_count_aligned_p)
+#define __pyx_get_slice_count(memview) (*__pyx_get_slice_count_pointer(memview))
+#define __PYX_INC_MEMVIEW(slice, have_gil) __Pyx_INC_MEMVIEW(slice, have_gil, __LINE__)
+#define __PYX_XDEC_MEMVIEW(slice, have_gil) __Pyx_XDEC_MEMVIEW(slice, have_gil, __LINE__)
+static CYTHON_INLINE void __Pyx_INC_MEMVIEW(__Pyx_memviewslice *, int, int);
+static CYTHON_INLINE void __Pyx_XDEC_MEMVIEW(__Pyx_memviewslice *, int, int);
+
+/* PyThreadStateGet.proto */
+/* With CYTHON_FAST_THREAD_STATE, callers declare and assign a local
+ * __pyx_tstate and the error check reads its curexc_type field directly;
+ * otherwise the declare/assign macros are empty and PyErr_Occurred() is used. */
+#if CYTHON_FAST_THREAD_STATE
+#define __Pyx_PyThreadState_declare PyThreadState *__pyx_tstate;
+#define __Pyx_PyThreadState_assign __pyx_tstate = __Pyx_PyThreadState_Current;
+#define __Pyx_PyErr_Occurred() __pyx_tstate->curexc_type
+#else
+#define __Pyx_PyThreadState_declare
+#define __Pyx_PyThreadState_assign
+#define __Pyx_PyErr_Occurred() PyErr_Occurred()
+#endif
+
+/* PyErrFetchRestore.proto */
+/* Exception fetch/restore wrappers.  In the fast path they operate on an
+ * explicit thread state: the plain forms use the caller's __pyx_tstate, the
+ * *WithState forms look it up via PyThreadState_GET().  In the fallback path
+ * every form maps onto PyErr_Fetch / PyErr_Restore / PyErr_Clear / PyErr_SetNone. */
+#if CYTHON_FAST_THREAD_STATE
+#define __Pyx_PyErr_Clear() __Pyx_ErrRestore(NULL, NULL, NULL)
+#define __Pyx_ErrRestoreWithState(type, value, tb) __Pyx_ErrRestoreInState(PyThreadState_GET(), type, value, tb)
+#define __Pyx_ErrFetchWithState(type, value, tb) __Pyx_ErrFetchInState(PyThreadState_GET(), type, value, tb)
+#define __Pyx_ErrRestore(type, value, tb) __Pyx_ErrRestoreInState(__pyx_tstate, type, value, tb)
+#define __Pyx_ErrFetch(type, value, tb) __Pyx_ErrFetchInState(__pyx_tstate, type, value, tb)
+static CYTHON_INLINE void __Pyx_ErrRestoreInState(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb);
+static CYTHON_INLINE void __Pyx_ErrFetchInState(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb);
+#if CYTHON_COMPILING_IN_CPYTHON
+/* Takes a new reference to exc and installs it as the error type with no value or traceback. */
+#define __Pyx_PyErr_SetNone(exc) (Py_INCREF(exc), __Pyx_ErrRestore((exc), NULL, NULL))
+#else
+#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc)
+#endif
+#else
+#define __Pyx_PyErr_Clear() PyErr_Clear()
+#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc)
+#define __Pyx_ErrRestoreWithState(type, value, tb) PyErr_Restore(type, value, tb)
+#define __Pyx_ErrFetchWithState(type, value, tb) PyErr_Fetch(type, value, tb)
+#define __Pyx_ErrRestoreInState(tstate, type, value, tb) PyErr_Restore(type, value, tb)
+#define __Pyx_ErrFetchInState(tstate, type, value, tb) PyErr_Fetch(type, value, tb)
+#define __Pyx_ErrRestore(type, value, tb) PyErr_Restore(type, value, tb)
+#define __Pyx_ErrFetch(type, value, tb) PyErr_Fetch(type, value, tb)
+#endif
+
+/* WriteUnraisableException.proto */
+static void __Pyx_WriteUnraisable(const char *name, int clineno,
+ int lineno, const char *filename,
+ int full_traceback, int nogil);
+
+/* PyDictVersioning.proto */
+/* Dict-version based lookup caching.  When enabled, a lookup result is cached
+ * in function-static variables and reused while the dict's ma_version_tag is
+ * unchanged; when disabled, the macros degrade to a plain lookup. */
+#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_TYPE_SLOTS
+#define __PYX_DICT_VERSION_INIT ((PY_UINT64_T) -1)
+#define __PYX_GET_DICT_VERSION(dict) (((PyDictObject*)(dict))->ma_version_tag)
+#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var)\
+ (version_var) = __PYX_GET_DICT_VERSION(dict);\
+ (cache_var) = (value);
+#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) {\
+ static PY_UINT64_T __pyx_dict_version = 0;\
+ static PyObject *__pyx_dict_cached_value = NULL;\
+ if (likely(__PYX_GET_DICT_VERSION(DICT) == __pyx_dict_version)) {\
+ (VAR) = __pyx_dict_cached_value;\
+ } else {\
+ (VAR) = __pyx_dict_cached_value = (LOOKUP);\
+ __pyx_dict_version = __PYX_GET_DICT_VERSION(DICT);\
+ }\
+}
+static CYTHON_INLINE PY_UINT64_T __Pyx_get_tp_dict_version(PyObject *obj);
+static CYTHON_INLINE PY_UINT64_T __Pyx_get_object_dict_version(PyObject *obj);
+static CYTHON_INLINE int __Pyx_object_dict_version_matches(PyObject* obj, PY_UINT64_T tp_dict_version, PY_UINT64_T obj_dict_version);
+#else
+#define __PYX_GET_DICT_VERSION(dict) (0)
+#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var)
+#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) (VAR) = (LOOKUP);
+#endif
+
+/* None.proto */
+static CYTHON_INLINE void __Pyx_RaiseUnboundLocalError(const char *varname);
+
+/* PyCFunctionFastCall.proto */
+#if CYTHON_FAST_PYCCALL
+static CYTHON_INLINE PyObject *__Pyx_PyCFunction_FastCall(PyObject *func, PyObject **args, Py_ssize_t nargs);
+#else
+#define __Pyx_PyCFunction_FastCall(func, args, nargs) (assert(0), NULL)
+#endif
+
+/* PyFunctionFastCall.proto */
+/* Fast calling of Python functions; everything here is compiled only when
+ * CYTHON_FAST_PYCALL is set. */
+#if CYTHON_FAST_PYCALL
+#define __Pyx_PyFunction_FastCall(func, args, nargs)\
+ __Pyx_PyFunction_FastCallDict((func), (args), (nargs), NULL)
+/* The leading `1 ||` makes this condition always true, so the module's own
+ * __Pyx_PyFunction_FastCallDict prototype is always selected. */
+#if 1 || PY_VERSION_HEX < 0x030600B1
+static PyObject *__Pyx_PyFunction_FastCallDict(PyObject *func, PyObject **args, Py_ssize_t nargs, PyObject *kwargs);
+#else
+#define __Pyx_PyFunction_FastCallDict(func, args, nargs, kwargs) _PyFunction_FastCallDict(func, args, nargs, kwargs)
+#endif
+/* Compile-time assertion usable in expressions: a false cond yields a
+ * negative array size (compile error); a true cond evaluates to 0. */
+#define __Pyx_BUILD_ASSERT_EXPR(cond)\
+ (sizeof(char [1 - 2*!(cond)]) - 1)
+#ifndef Py_MEMBER_SIZE
+#define Py_MEMBER_SIZE(type, member) sizeof(((type *)0)->member)
+#endif
+#if CYTHON_FAST_PYCALL
+ static size_t __pyx_pyframe_localsplus_offset = 0;
+ #include "frameobject.h"
+#if PY_VERSION_HEX >= 0x030b00a6
+ #ifndef Py_BUILD_CORE
+ #define Py_BUILD_CORE 1
+ #endif
+ #include "internal/pycore_frame.h"
+#endif
+ /* Records the byte offset of f_localsplus as PyFrame_Type.tp_basicsize minus
+ * the member's size, after asserting that f_localsplus is the last member of
+ * PyFrameObject.  (The macro name is spelled "__Pxy" in the generated code.) */
+ #define __Pxy_PyFrame_Initialize_Offsets()\
+ ((void)__Pyx_BUILD_ASSERT_EXPR(sizeof(PyFrameObject) == offsetof(PyFrameObject, f_localsplus) + Py_MEMBER_SIZE(PyFrameObject, f_localsplus)),\
+ (void)(__pyx_pyframe_localsplus_offset = ((size_t)PyFrame_Type.tp_basicsize) - Py_MEMBER_SIZE(PyFrameObject, f_localsplus)))
+ /* Asserts the offset has been initialised (non-zero) before using it. */
+ #define __Pyx_PyFrame_GetLocalsplus(frame)\
+ (assert(__pyx_pyframe_localsplus_offset), (PyObject **)(((char *)(frame)) + __pyx_pyframe_localsplus_offset))
+#endif // CYTHON_FAST_PYCALL
+#endif
+
+/* PyObjectCall.proto */
+#if CYTHON_COMPILING_IN_CPYTHON
+static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg, PyObject *kw);
+#else
+#define __Pyx_PyObject_Call(func, arg, kw) PyObject_Call(func, arg, kw)
+#endif
+
+/* PyObjectCall2Args.proto */
+static CYTHON_UNUSED PyObject* __Pyx_PyObject_Call2Args(PyObject* function, PyObject* arg1, PyObject* arg2);
+
+/* PyObjectCallMethO.proto */
+#if CYTHON_COMPILING_IN_CPYTHON
+static CYTHON_INLINE PyObject* __Pyx_PyObject_CallMethO(PyObject *func, PyObject *arg);
+#endif
+
+/* PyObjectCallOneArg.proto */
+static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg);
+
+/* GetModuleGlobalName.proto */
+/* Module-global name lookup.  With dict versioning, the cached variant keeps
+ * a function-static version/value pair: while the version of __pyx_d is
+ * unchanged it returns a new reference to the cached value (or falls back to
+ * __Pyx_GetBuiltinName when the cached value is NULL); otherwise it performs
+ * a full lookup that refreshes the cache.  The uncached variant uses
+ * throwaway locals instead of statics. */
+#if CYTHON_USE_DICT_VERSIONS
+#define __Pyx_GetModuleGlobalName(var, name) do {\
+ static PY_UINT64_T __pyx_dict_version = 0;\
+ static PyObject *__pyx_dict_cached_value = NULL;\
+ (var) = (likely(__pyx_dict_version == __PYX_GET_DICT_VERSION(__pyx_d))) ?\
+ (likely(__pyx_dict_cached_value) ? __Pyx_NewRef(__pyx_dict_cached_value) : __Pyx_GetBuiltinName(name)) :\
+ __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\
+} while(0)
+#define __Pyx_GetModuleGlobalNameUncached(var, name) do {\
+ PY_UINT64_T __pyx_dict_version;\
+ PyObject *__pyx_dict_cached_value;\
+ (var) = __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\
+} while(0)
+static PyObject *__Pyx__GetModuleGlobalName(PyObject *name, PY_UINT64_T *dict_version, PyObject **dict_cached_value);
+#else
+#define __Pyx_GetModuleGlobalName(var, name) (var) = __Pyx__GetModuleGlobalName(name)
+#define __Pyx_GetModuleGlobalNameUncached(var, name) (var) = __Pyx__GetModuleGlobalName(name)
+static CYTHON_INLINE PyObject *__Pyx__GetModuleGlobalName(PyObject *name);
+#endif
+
+/* RaiseException.proto */
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause);
+
+/* GetTopmostException.proto */
+#if CYTHON_USE_EXC_INFO_STACK
+static _PyErr_StackItem * __Pyx_PyErr_GetTopmostException(PyThreadState *tstate);
+#endif
+
+/* SaveResetException.proto */
+#if CYTHON_FAST_THREAD_STATE
+#define __Pyx_ExceptionSave(type, value, tb) __Pyx__ExceptionSave(__pyx_tstate, type, value, tb)
+static CYTHON_INLINE void __Pyx__ExceptionSave(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb);
+#define __Pyx_ExceptionReset(type, value, tb) __Pyx__ExceptionReset(__pyx_tstate, type, value, tb)
+static CYTHON_INLINE void __Pyx__ExceptionReset(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb);
+#else
+#define __Pyx_ExceptionSave(type, value, tb) PyErr_GetExcInfo(type, value, tb)
+#define __Pyx_ExceptionReset(type, value, tb) PyErr_SetExcInfo(type, value, tb)
+#endif
+
+/* PyErrExceptionMatches.proto */
+#if CYTHON_FAST_THREAD_STATE
+#define __Pyx_PyErr_ExceptionMatches(err) __Pyx_PyErr_ExceptionMatchesInState(__pyx_tstate, err)
+static CYTHON_INLINE int __Pyx_PyErr_ExceptionMatchesInState(PyThreadState* tstate, PyObject* err);
+#else
+#define __Pyx_PyErr_ExceptionMatches(err) PyErr_ExceptionMatches(err)
+#endif
+
+/* GetException.proto */
+#if CYTHON_FAST_THREAD_STATE
+#define __Pyx_GetException(type, value, tb) __Pyx__GetException(__pyx_tstate, type, value, tb)
+static int __Pyx__GetException(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb);
+#else
+static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb);
+#endif
+
+/* ArgTypeTest.proto */
+#define __Pyx_ArgTypeTest(obj, type, none_allowed, name, exact)\
+ ((likely((Py_TYPE(obj) == type) | (none_allowed && (obj == Py_None)))) ? 1 :\
+ __Pyx__ArgTypeTest(obj, type, name, exact))
+static int __Pyx__ArgTypeTest(PyObject *obj, PyTypeObject *type, const char *name, int exact);
+
+/* IncludeStringH.proto */
+#include <string.h> /* C string/memory routines for the string-comparison helpers declared below */
+
+/* BytesEquals.proto */
+static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int equals);
+
+/* UnicodeEquals.proto */
+static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int equals);
+
+/* StrEquals.proto */
+#if PY_MAJOR_VERSION >= 3
+#define __Pyx_PyString_Equals __Pyx_PyUnicode_Equals
+#else
+#define __Pyx_PyString_Equals __Pyx_PyBytes_Equals
+#endif
+
+/* DivInt[Py_ssize_t].proto */
+static CYTHON_INLINE Py_ssize_t __Pyx_div_Py_ssize_t(Py_ssize_t, Py_ssize_t);
+
+/* UnaryNegOverflows.proto -- nonzero when x is negative and, converted to unsigned long, equals its own unsigned negation (0 - x); note that x is expanded three times */
+#define UNARY_NEG_WOULD_OVERFLOW(x)\
+ (((x) < 0) & ((unsigned long)(x) == 0-(unsigned long)(x)))
+
+static CYTHON_UNUSED int __pyx_array_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/
+static PyObject *__pyx_array_get_memview(struct __pyx_array_obj *); /*proto*/
+/* GetAttr.proto */
+static CYTHON_INLINE PyObject *__Pyx_GetAttr(PyObject *, PyObject *);
+
+/* GetItemInt.proto */
+#define __Pyx_GetItemInt(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\
+ (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\
+ __Pyx_GetItemInt_Fast(o, (Py_ssize_t)i, is_list, wraparound, boundscheck) :\
+ (is_list ? (PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL) :\
+ __Pyx_GetItemInt_Generic(o, to_py_func(i))))
+#define __Pyx_GetItemInt_List(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\
+ (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\
+ __Pyx_GetItemInt_List_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\
+ (PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL))
+static CYTHON_INLINE PyObject *__Pyx_GetItemInt_List_Fast(PyObject *o, Py_ssize_t i,
+ int wraparound, int boundscheck);
+#define __Pyx_GetItemInt_Tuple(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\
+ (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\
+ __Pyx_GetItemInt_Tuple_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\
+ (PyErr_SetString(PyExc_IndexError, "tuple index out of range"), (PyObject*)NULL))
+static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Tuple_Fast(PyObject *o, Py_ssize_t i,
+ int wraparound, int boundscheck);
+static PyObject *__Pyx_GetItemInt_Generic(PyObject *o, PyObject* j);
+static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Fast(PyObject *o, Py_ssize_t i,
+ int is_list, int wraparound, int boundscheck);
+
+/* ObjectGetItem.proto */
+#if CYTHON_USE_TYPE_SLOTS
+static CYTHON_INLINE PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject* key);
+#else
+#define __Pyx_PyObject_GetItem(obj, key) PyObject_GetItem(obj, key)
+#endif
+
+/* decode_c_string_utf16.proto */
+static CYTHON_INLINE PyObject *__Pyx_PyUnicode_DecodeUTF16(const char *s, Py_ssize_t size, const char *errors) {
+ int order = 0; /* 0: no byte order forced; per the CPython docs the decoder uses native order unless a BOM is present */
+ return PyUnicode_DecodeUTF16(s, size, errors, &order); /* result (or NULL) is passed straight back to the caller */
+}
+static CYTHON_INLINE PyObject *__Pyx_PyUnicode_DecodeUTF16LE(const char *s, Py_ssize_t size, const char *errors) {
+ int order = -1; /* -1: force little-endian decoding */
+ return PyUnicode_DecodeUTF16(s, size, errors, &order); /* result (or NULL) is passed straight back to the caller */
+}
+static CYTHON_INLINE PyObject *__Pyx_PyUnicode_DecodeUTF16BE(const char *s, Py_ssize_t size, const char *errors) {
+ int order = 1; /* 1: force big-endian decoding */
+ return PyUnicode_DecodeUTF16(s, size, errors, &order); /* result (or NULL) is passed straight back to the caller */
+}
+
+/* decode_c_string.proto */
+static CYTHON_INLINE PyObject* __Pyx_decode_c_string(
+ const char* cstring, Py_ssize_t start, Py_ssize_t stop,
+ const char* encoding, const char* errors,
+ PyObject* (*decode_func)(const char *s, Py_ssize_t size, const char *errors));
+
+/* GetAttr3.proto */
+static CYTHON_INLINE PyObject *__Pyx_GetAttr3(PyObject *, PyObject *, PyObject *);
+
+/* RaiseTooManyValuesToUnpack.proto */
+static CYTHON_INLINE void __Pyx_RaiseTooManyValuesError(Py_ssize_t expected);
+
+/* RaiseNeedMoreValuesToUnpack.proto */
+static CYTHON_INLINE void __Pyx_RaiseNeedMoreValuesError(Py_ssize_t index);
+
+/* RaiseNoneIterError.proto */
+static CYTHON_INLINE void __Pyx_RaiseNoneNotIterableError(void);
+
+/* ExtTypeTest.proto */
+static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type);
+
+/* SwapException.proto */
+#if CYTHON_FAST_THREAD_STATE
+#define __Pyx_ExceptionSwap(type, value, tb) __Pyx__ExceptionSwap(__pyx_tstate, type, value, tb)
+static CYTHON_INLINE void __Pyx__ExceptionSwap(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb);
+#else
+static CYTHON_INLINE void __Pyx_ExceptionSwap(PyObject **type, PyObject **value, PyObject **tb);
+#endif
+
+/* Import.proto */
+static PyObject *__Pyx_Import(PyObject *name, PyObject *from_list, int level);
+
+/* FastTypeChecks.proto */
+#if CYTHON_COMPILING_IN_CPYTHON
+#define __Pyx_TypeCheck(obj, type) __Pyx_IsSubtype(Py_TYPE(obj), (PyTypeObject *)type)
+static CYTHON_INLINE int __Pyx_IsSubtype(PyTypeObject *a, PyTypeObject *b);
+static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches(PyObject *err, PyObject *type);
+static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches2(PyObject *err, PyObject *type1, PyObject *type2);
+#else
+#define __Pyx_TypeCheck(obj, type) PyObject_TypeCheck(obj, (PyTypeObject *)type)
+#define __Pyx_PyErr_GivenExceptionMatches(err, type) PyErr_GivenExceptionMatches(err, type)
+#define __Pyx_PyErr_GivenExceptionMatches2(err, type1, type2) (PyErr_GivenExceptionMatches(err, type1) || PyErr_GivenExceptionMatches(err, type2))
+#endif
+#define __Pyx_PyException_Check(obj) __Pyx_TypeCheck(obj, PyExc_Exception)
+
+static CYTHON_UNUSED int __pyx_memoryview_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/
+/* ListCompAppend.proto */
+#if CYTHON_USE_PYLIST_INTERNALS && CYTHON_ASSUME_SAFE_MACROS
+static CYTHON_INLINE int __Pyx_ListComp_Append(PyObject* list, PyObject* x) {
+ PyListObject* lst = (PyListObject*) list;
+ const Py_ssize_t used = Py_SIZE(list);
+ /* No spare slot in the allocated storage: defer to the generic PyList_Append. */
+ if (unlikely(lst->allocated <= used))
+  return PyList_Append(list, x);
+ Py_INCREF(x); /* the slot written below keeps its own reference to x */
+ PyList_SET_ITEM(list, used, x);
+ __Pyx_SET_SIZE(list, used + 1);
+ return 0;
+}
+#else
+#define __Pyx_ListComp_Append(L,x) PyList_Append(L,x)
+#endif
+
+/* PyIntBinop.proto */
+#if !CYTHON_COMPILING_IN_PYPY
+static PyObject* __Pyx_PyInt_AddObjC(PyObject *op1, PyObject *op2, long intval, int inplace, int zerodivision_check);
+#else
+#define __Pyx_PyInt_AddObjC(op1, op2, intval, inplace, zerodivision_check)\
+ (inplace ? PyNumber_InPlaceAdd(op1, op2) : PyNumber_Add(op1, op2))
+#endif
+
+/* ListExtend.proto */
+static CYTHON_INLINE int __Pyx_PyList_Extend(PyObject* L, PyObject* v) {
+#if CYTHON_COMPILING_IN_CPYTHON
+ PyObject* result = _PyList_Extend((PyListObject*)L, v); /* owned object on success, NULL on failure */
+ if (unlikely(result == NULL))
+  return -1;
+ Py_DECREF(result); /* only the success/failure status is reported to the caller */
+ return 0;
+#else
+ return PyList_SetSlice(L, PY_SSIZE_T_MAX, PY_SSIZE_T_MAX, v); /* slice assignment at the far end of L */
+#endif
+}
+
+/* ListAppend.proto */
+#if CYTHON_USE_PYLIST_INTERNALS && CYTHON_ASSUME_SAFE_MACROS
+static CYTHON_INLINE int __Pyx_PyList_Append(PyObject* list, PyObject* x) {
+ PyListObject* lst = (PyListObject*) list;
+ const Py_ssize_t used = Py_SIZE(list);
+ /* In-place store only when a slot is free AND the list is more than half full; otherwise use PyList_Append. */
+ if (unlikely(used >= lst->allocated) | unlikely(used <= (lst->allocated >> 1)))
+  return PyList_Append(list, x);
+ Py_INCREF(x); /* the slot written below keeps its own reference to x */
+ PyList_SET_ITEM(list, used, x);
+ __Pyx_SET_SIZE(list, used + 1);
+ return 0;
+}
+#else
+#define __Pyx_PyList_Append(L,x) PyList_Append(L,x)
+#endif
+
+/* DivInt[long].proto */
+static CYTHON_INLINE long __Pyx_div_long(long, long);
+
+/* PySequenceContains.proto */
+static CYTHON_INLINE int __Pyx_PySequence_ContainsTF(PyObject* item, PyObject* seq, int eq) {
+ const int found = PySequence_Contains(seq, item); /* negative values are errors and are returned unchanged */
+ return likely(found >= 0) ? (found == (eq == Py_EQ)) : found; /* eq == Py_EQ keeps the membership result, any other eq inverts it */
+}
+
+/* ImportFrom.proto */
+static PyObject* __Pyx_ImportFrom(PyObject* module, PyObject* name);
+
+/* HasAttr.proto */
+static CYTHON_INLINE int __Pyx_HasAttr(PyObject *, PyObject *);
+
+/* PyObject_GenericGetAttrNoDict.proto */
+#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000
+static CYTHON_INLINE PyObject* __Pyx_PyObject_GenericGetAttrNoDict(PyObject* obj, PyObject* attr_name);
+#else
+#define __Pyx_PyObject_GenericGetAttrNoDict PyObject_GenericGetAttr
+#endif
+
+/* PyObject_GenericGetAttr.proto */
+#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000
+static PyObject* __Pyx_PyObject_GenericGetAttr(PyObject* obj, PyObject* attr_name);
+#else
+#define __Pyx_PyObject_GenericGetAttr PyObject_GenericGetAttr
+#endif
+
+/* SetVTable.proto */
+static int __Pyx_SetVtable(PyObject *dict, void *vtable);
+
+/* PyObjectGetAttrStrNoError.proto */
+static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStrNoError(PyObject* obj, PyObject* attr_name);
+
+/* SetupReduce.proto */
+static int __Pyx_setup_reduce(PyObject* type_obj);
+
+/* TypeImport.proto */
+#ifndef __PYX_HAVE_RT_ImportType_proto /* guard: declare the enum and prototype only once */
+#define __PYX_HAVE_RT_ImportType_proto
+enum __Pyx_ImportType_CheckSize { /* type of the check_size argument of __Pyx_ImportType below */
+ __Pyx_ImportType_CheckSize_Error = 0, /* NOTE(review): by name, a size mismatch is an error -- confirm in __Pyx_ImportType */
+ __Pyx_ImportType_CheckSize_Warn = 1, /* NOTE(review): by name, a size mismatch only warns -- confirm in __Pyx_ImportType */
+ __Pyx_ImportType_CheckSize_Ignore = 2 /* NOTE(review): by name, the size check is skipped -- confirm in __Pyx_ImportType */
+};
+static PyTypeObject *__Pyx_ImportType(PyObject* module, const char *module_name, const char *class_name, size_t size, enum __Pyx_ImportType_CheckSize check_size);
+#endif
+
+/* CLineInTraceback.proto */
+#ifdef CYTHON_CLINE_IN_TRACEBACK
+#define __Pyx_CLineForTraceback(tstate, c_line) (((CYTHON_CLINE_IN_TRACEBACK)) ? c_line : 0)
+#else
+static int __Pyx_CLineForTraceback(PyThreadState *tstate, int c_line);
+#endif
+
+/* CodeObjectCache.proto */
+typedef struct { /* one cache entry: a code object paired with an integer line key */
+ PyCodeObject* code_object;
+ int code_line; /* same key type that the find/insert helpers below take */
+} __Pyx_CodeObjectCacheEntry;
+struct __Pyx_CodeObjectCache {
+ int count; /* presumably the number of entries in use -- TODO confirm in the implementation */
+ int max_count; /* presumably the allocated capacity of entries -- TODO confirm in the implementation */
+ __Pyx_CodeObjectCacheEntry* entries;
+};
+static struct __Pyx_CodeObjectCache __pyx_code_cache = {0,0,NULL}; /* file-local cache instance; starts with zero counts and no entries array */
+static int __pyx_bisect_code_objects(__Pyx_CodeObjectCacheEntry* entries, int count, int code_line);
+static PyCodeObject *__pyx_find_code_object(int code_line);
+static void __pyx_insert_code_object(int code_line, PyCodeObject* code_object);
+
+/* AddTraceback.proto */
+static void __Pyx_AddTraceback(const char *funcname, int c_line,
+ int py_line, const char *filename);
+
+/* None.proto */
+#include <new> /* header name restored; Cython emits <new> here for C++ modules */
+
+#if PY_MAJOR_VERSION < 3
+ static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
+ static void __Pyx_ReleaseBuffer(Py_buffer *view);
+#else
+ #define __Pyx_GetBuffer PyObject_GetBuffer
+ #define __Pyx_ReleaseBuffer PyBuffer_Release
+#endif
+
+
+/* BufferStructDeclare.proto */
+typedef struct { /* three Py_ssize_t values describing one dimension */
+ Py_ssize_t shape, strides, suboffsets;
+} __Pyx_Buf_DimInfo;
+typedef struct { /* a Py_buffer stored next to a counter */
+ size_t refcount; /* NOTE(review): presumably counts users of pybuffer -- confirm where it is updated */
+ Py_buffer pybuffer;
+} __Pyx_Buffer;
+typedef struct {
+ __Pyx_Buffer *rcbuffer; /* pointer to the counted buffer wrapper declared above */
+ char *data;
+ __Pyx_Buf_DimInfo diminfo[8]; /* fixed room for 8 per-dimension records */
+} __Pyx_LocalBuf_ND;
+
+/* MemviewSliceIsContig.proto */
+static int __pyx_memviewslice_is_contig(const __Pyx_memviewslice mvs, char order, int ndim);
+
+/* OverlappingSlices.proto */
+static int __pyx_slices_overlap(__Pyx_memviewslice *slice1,
+ __Pyx_memviewslice *slice2,
+ int ndim, size_t itemsize);
+
+/* Capsule.proto */
+static CYTHON_INLINE PyObject *__pyx_capsule_create(void *p, const char *sig);
+
+/* IsLittleEndian.proto */
+static CYTHON_INLINE int __Pyx_Is_Little_Endian(void);
+
+/* BufferFormatCheck.proto */
+static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const char* ts);
+static void __Pyx_BufFmt_Init(__Pyx_BufFmt_Context* ctx,
+ __Pyx_BufFmt_StackElem* stack,
+ __Pyx_TypeInfo* type);
+
+/* TypeInfoCompare.proto */
+static int __pyx_typeinfo_cmp(__Pyx_TypeInfo *a, __Pyx_TypeInfo *b);
+
+/* MemviewSliceValidateAndInit.proto */
+static int __Pyx_ValidateAndInit_memviewslice(
+ int *axes_specs,
+ int c_or_f_flag,
+ int buf_flags,
+ int ndim,
+ __Pyx_TypeInfo *dtype,
+ __Pyx_BufFmt_StackElem stack[],
+ __Pyx_memviewslice *memviewslice,
+ PyObject *original_obj);
+
+/* ObjectToMemviewSlice.proto */
+static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_dsdsds_double(PyObject *, int writable_flag);
+
+/* GCCDiagnostics.proto */
+#if defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))
+#define __Pyx_HAS_GCC_DIAGNOSTIC
+#endif
+
+/* ObjectToMemviewSlice.proto */
+static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_dsds_double(PyObject *, int writable_flag);
+
+/* CppExceptionConversion.proto */
+#ifndef __Pyx_CppExn2PyErr
+#include <new> /* std::bad_alloc */
+#include <typeinfo> /* std::bad_cast, std::bad_typeid */
+#include <stdexcept> /* std::domain_error, std::invalid_argument, std::out_of_range, std::overflow_error, std::range_error, std::underflow_error */
+#include <ios> /* std::ios_base::failure */
+static void __Pyx_CppExn2PyErr() { /* sets a Python error for the C++ exception being handled; the bare `throw;` below needs an active exception, so call this from inside a catch block */
+ try {
+ if (PyErr_Occurred())
+ ; // let the latest Python exn pass through and ignore the current one
+ else
+ throw; /* re-raise the in-flight C++ exception so the handlers below can classify it */
+ } catch (const std::bad_alloc& exn) {
+ PyErr_SetString(PyExc_MemoryError, exn.what());
+ } catch (const std::bad_cast& exn) {
+ PyErr_SetString(PyExc_TypeError, exn.what());
+ } catch (const std::bad_typeid& exn) {
+ PyErr_SetString(PyExc_TypeError, exn.what());
+ } catch (const std::domain_error& exn) {
+ PyErr_SetString(PyExc_ValueError, exn.what());
+ } catch (const std::invalid_argument& exn) {
+ PyErr_SetString(PyExc_ValueError, exn.what());
+ } catch (const std::ios_base::failure& exn) {
+ PyErr_SetString(PyExc_IOError, exn.what());
+ } catch (const std::out_of_range& exn) {
+ PyErr_SetString(PyExc_IndexError, exn.what());
+ } catch (const std::overflow_error& exn) {
+ PyErr_SetString(PyExc_OverflowError, exn.what());
+ } catch (const std::range_error& exn) {
+ PyErr_SetString(PyExc_ArithmeticError, exn.what());
+ } catch (const std::underflow_error& exn) {
+ PyErr_SetString(PyExc_ArithmeticError, exn.what());
+ } catch (const std::exception& exn) { /* any other std::exception maps to RuntimeError */
+ PyErr_SetString(PyExc_RuntimeError, exn.what());
+ }
+ catch (...)
+ {
+ PyErr_SetString(PyExc_RuntimeError, "Unknown exception");
+ }
+}
+#endif
+
+/* MemviewDtypeToObject.proto */
+static CYTHON_INLINE PyObject *__pyx_memview_get_double(const char *itemp);
+static CYTHON_INLINE int __pyx_memview_set_double(const char *itemp, PyObject *obj);
+
+/* ObjectToMemviewSlice.proto */
+static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_ds_int(PyObject *, int writable_flag);
+
+/* RealImag.proto */
+#if CYTHON_CCOMPLEX
+ #ifdef __cplusplus
+ #define __Pyx_CREAL(z) ((z).real())
+ #define __Pyx_CIMAG(z) ((z).imag())
+ #else
+ #define __Pyx_CREAL(z) (__real__(z))
+ #define __Pyx_CIMAG(z) (__imag__(z))
+ #endif
+#else
+ #define __Pyx_CREAL(z) ((z).real)
+ #define __Pyx_CIMAG(z) ((z).imag)
+#endif
+#if defined(__cplusplus) && CYTHON_CCOMPLEX\
+ && (defined(_WIN32) || defined(__clang__) || (defined(__GNUC__) && (__GNUC__ >= 5 || __GNUC__ == 4 && __GNUC_MINOR__ >= 4 )) || __cplusplus >= 201103)
+ #define __Pyx_SET_CREAL(z,x) ((z).real(x))
+ #define __Pyx_SET_CIMAG(z,y) ((z).imag(y))
+#else
+ #define __Pyx_SET_CREAL(z,x) __Pyx_CREAL(z) = (x)
+ #define __Pyx_SET_CIMAG(z,y) __Pyx_CIMAG(z) = (y)
+#endif
+
+/* Arithmetic.proto */
+#if CYTHON_CCOMPLEX
+ #define __Pyx_c_eq_float(a, b) ((a)==(b))
+ #define __Pyx_c_sum_float(a, b) ((a)+(b))
+ #define __Pyx_c_diff_float(a, b) ((a)-(b))
+ #define __Pyx_c_prod_float(a, b) ((a)*(b))
+ #define __Pyx_c_quot_float(a, b) ((a)/(b))
+ #define __Pyx_c_neg_float(a) (-(a))
+ #ifdef __cplusplus
+ #define __Pyx_c_is_zero_float(z) ((z)==(float)0)
+ #define __Pyx_c_conj_float(z) (::std::conj(z))
+ #if 1
+ #define __Pyx_c_abs_float(z) (::std::abs(z))
+ #define __Pyx_c_pow_float(a, b) (::std::pow(a, b))
+ #endif
+ #else
+ #define __Pyx_c_is_zero_float(z) ((z)==0)
+ #define __Pyx_c_conj_float(z) (conjf(z))
+ #if 1
+ #define __Pyx_c_abs_float(z) (cabsf(z))
+ #define __Pyx_c_pow_float(a, b) (cpowf(a, b))
+ #endif
+ #endif
+#else
+ static CYTHON_INLINE int __Pyx_c_eq_float(__pyx_t_float_complex, __pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_sum_float(__pyx_t_float_complex, __pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_diff_float(__pyx_t_float_complex, __pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_prod_float(__pyx_t_float_complex, __pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_quot_float(__pyx_t_float_complex, __pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_neg_float(__pyx_t_float_complex);
+ static CYTHON_INLINE int __Pyx_c_is_zero_float(__pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_conj_float(__pyx_t_float_complex);
+ #if 1
+ static CYTHON_INLINE float __Pyx_c_abs_float(__pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_pow_float(__pyx_t_float_complex, __pyx_t_float_complex);
+ #endif
+#endif
+
+/* Arithmetic.proto */
+#if CYTHON_CCOMPLEX
+ #define __Pyx_c_eq_double(a, b) ((a)==(b))
+ #define __Pyx_c_sum_double(a, b) ((a)+(b))
+ #define __Pyx_c_diff_double(a, b) ((a)-(b))
+ #define __Pyx_c_prod_double(a, b) ((a)*(b))
+ #define __Pyx_c_quot_double(a, b) ((a)/(b))
+ #define __Pyx_c_neg_double(a) (-(a))
+ #ifdef __cplusplus
+ #define __Pyx_c_is_zero_double(z) ((z)==(double)0)
+ #define __Pyx_c_conj_double(z) (::std::conj(z))
+ #if 1
+ #define __Pyx_c_abs_double(z) (::std::abs(z))
+ #define __Pyx_c_pow_double(a, b) (::std::pow(a, b))
+ #endif
+ #else
+ #define __Pyx_c_is_zero_double(z) ((z)==0)
+ #define __Pyx_c_conj_double(z) (conj(z))
+ #if 1
+ #define __Pyx_c_abs_double(z) (cabs(z))
+ #define __Pyx_c_pow_double(a, b) (cpow(a, b))
+ #endif
+ #endif
+#else
+ static CYTHON_INLINE int __Pyx_c_eq_double(__pyx_t_double_complex, __pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_sum_double(__pyx_t_double_complex, __pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_diff_double(__pyx_t_double_complex, __pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_prod_double(__pyx_t_double_complex, __pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_quot_double(__pyx_t_double_complex, __pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_neg_double(__pyx_t_double_complex);
+ static CYTHON_INLINE int __Pyx_c_is_zero_double(__pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_conj_double(__pyx_t_double_complex);
+ #if 1
+ static CYTHON_INLINE double __Pyx_c_abs_double(__pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_pow_double(__pyx_t_double_complex, __pyx_t_double_complex);
+ #endif
+#endif
+
+/* MemviewSliceCopyTemplate.proto */
+static __Pyx_memviewslice
+__pyx_memoryview_copy_new_contig(const __Pyx_memviewslice *from_mvs,
+ const char *mode, int ndim,
+ size_t sizeof_dtype, int contig_flag,
+ int dtype_is_object);
+
+/* CIntFromPy.proto */
+static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *);
+
+/* CIntToPy.proto */
+static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int(int value);
+
+/* CIntToPy.proto */
+static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value);
+
+/* CIntFromPy.proto */
+static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *);
+
+/* CIntFromPy.proto */
+static CYTHON_INLINE char __Pyx_PyInt_As_char(PyObject *);
+
+/* CheckBinaryVersion.proto */
+static int __Pyx_check_binary_version(void);
+
+/* InitStrings.proto */
+static int __Pyx_InitStrings(__Pyx_StringTabEntry *t);
+
+static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_13triangle_hash_TriangleHash *__pyx_v_self, __Pyx_memviewslice __pyx_v_triangles); /* proto*/
+static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_13triangle_hash_TriangleHash *__pyx_v_self, __Pyx_memviewslice __pyx_v_points, int __pyx_skip_dispatch); /* proto*/
+static PyObject *__pyx_array_get_memview(struct __pyx_array_obj *__pyx_v_self); /* proto*/
+static char *__pyx_memoryview_get_item_pointer(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index); /* proto*/
+static PyObject *__pyx_memoryview_is_slice(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj); /* proto*/
+static PyObject *__pyx_memoryview_setitem_slice_assignment(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_dst, PyObject *__pyx_v_src); /* proto*/
+static PyObject *__pyx_memoryview_setitem_slice_assign_scalar(struct __pyx_memoryview_obj *__pyx_v_self, struct __pyx_memoryview_obj *__pyx_v_dst, PyObject *__pyx_v_value); /* proto*/
+static PyObject *__pyx_memoryview_setitem_indexed(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value); /* proto*/
+static PyObject *__pyx_memoryview_convert_item_to_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp); /* proto*/
+static PyObject *__pyx_memoryview_assign_item_from_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value); /* proto*/
+static PyObject *__pyx_memoryviewslice_convert_item_to_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp); /* proto*/
+static PyObject *__pyx_memoryviewslice_assign_item_from_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value); /* proto*/
+
+/* Module declarations from 'cpython.buffer' */
+
+/* Module declarations from 'libc.string' */
+
+/* Module declarations from 'libc.stdio' */
+
+/* Module declarations from '__builtin__' */
+
+/* Module declarations from 'cpython.type' */
+static PyTypeObject *__pyx_ptype_7cpython_4type_type = 0;
+
+/* Module declarations from 'cpython' */
+
+/* Module declarations from 'cpython.object' */
+
+/* Module declarations from 'cpython.ref' */
+
+/* Module declarations from 'cpython.mem' */
+
+/* Module declarations from 'numpy' */
+
+/* Module declarations from 'numpy' */
+static PyTypeObject *__pyx_ptype_5numpy_dtype = 0;
+static PyTypeObject *__pyx_ptype_5numpy_flatiter = 0;
+static PyTypeObject *__pyx_ptype_5numpy_broadcast = 0;
+static PyTypeObject *__pyx_ptype_5numpy_ndarray = 0;
+static PyTypeObject *__pyx_ptype_5numpy_generic = 0;
+static PyTypeObject *__pyx_ptype_5numpy_number = 0;
+static PyTypeObject *__pyx_ptype_5numpy_integer = 0;
+static PyTypeObject *__pyx_ptype_5numpy_signedinteger = 0;
+static PyTypeObject *__pyx_ptype_5numpy_unsignedinteger = 0;
+static PyTypeObject *__pyx_ptype_5numpy_inexact = 0;
+static PyTypeObject *__pyx_ptype_5numpy_floating = 0;
+static PyTypeObject *__pyx_ptype_5numpy_complexfloating = 0;
+static PyTypeObject *__pyx_ptype_5numpy_flexible = 0;
+static PyTypeObject *__pyx_ptype_5numpy_character = 0;
+static PyTypeObject *__pyx_ptype_5numpy_ufunc = 0;
+
+/* Module declarations from 'cython.view' */
+
+/* Module declarations from 'cython' */
+
+/* Module declarations from 'libcpp.vector' */
+
+/* Module declarations from 'libc.math' */
+
+/* Module declarations from 'triangle_hash' */
+static PyTypeObject *__pyx_ptype_13triangle_hash_TriangleHash = 0;
+static PyTypeObject *__pyx_array_type = 0;
+static PyTypeObject *__pyx_MemviewEnum_type = 0;
+static PyTypeObject *__pyx_memoryview_type = 0;
+static PyTypeObject *__pyx_memoryviewslice_type = 0;
+static PyObject *generic = 0;
+static PyObject *strided = 0;
+static PyObject *indirect = 0;
+static PyObject *contiguous = 0;
+static PyObject *indirect_contiguous = 0;
+static int __pyx_memoryview_thread_locks_used;
+static PyThread_type_lock __pyx_memoryview_thread_locks[8];
+static struct __pyx_array_obj *__pyx_array_new(PyObject *, Py_ssize_t, char *, char *, char *); /*proto*/
+static void *__pyx_align_pointer(void *, size_t); /*proto*/
+static PyObject *__pyx_memoryview_new(PyObject *, int, int, __Pyx_TypeInfo *); /*proto*/
+static CYTHON_INLINE int __pyx_memoryview_check(PyObject *); /*proto*/
+static PyObject *_unellipsify(PyObject *, int); /*proto*/
+static PyObject *assert_direct_dimensions(Py_ssize_t *, int); /*proto*/
+static struct __pyx_memoryview_obj *__pyx_memview_slice(struct __pyx_memoryview_obj *, PyObject *); /*proto*/
+static int __pyx_memoryview_slice_memviewslice(__Pyx_memviewslice *, Py_ssize_t, Py_ssize_t, Py_ssize_t, int, int, int *, Py_ssize_t, Py_ssize_t, Py_ssize_t, int, int, int, int); /*proto*/
+static char *__pyx_pybuffer_index(Py_buffer *, char *, Py_ssize_t, Py_ssize_t); /*proto*/
+static int __pyx_memslice_transpose(__Pyx_memviewslice *); /*proto*/
+static PyObject *__pyx_memoryview_fromslice(__Pyx_memviewslice, int, PyObject *(*)(char *), int (*)(char *, PyObject *), int); /*proto*/
+static __Pyx_memviewslice *__pyx_memoryview_get_slice_from_memoryview(struct __pyx_memoryview_obj *, __Pyx_memviewslice *); /*proto*/
+static void __pyx_memoryview_slice_copy(struct __pyx_memoryview_obj *, __Pyx_memviewslice *); /*proto*/
+static PyObject *__pyx_memoryview_copy_object(struct __pyx_memoryview_obj *); /*proto*/
+static PyObject *__pyx_memoryview_copy_object_from_slice(struct __pyx_memoryview_obj *, __Pyx_memviewslice *); /*proto*/
+static Py_ssize_t abs_py_ssize_t(Py_ssize_t); /*proto*/
+static char __pyx_get_best_slice_order(__Pyx_memviewslice *, int); /*proto*/
+static void _copy_strided_to_strided(char *, Py_ssize_t *, char *, Py_ssize_t *, Py_ssize_t *, Py_ssize_t *, int, size_t); /*proto*/
+static void copy_strided_to_strided(__Pyx_memviewslice *, __Pyx_memviewslice *, int, size_t); /*proto*/
+static Py_ssize_t __pyx_memoryview_slice_get_size(__Pyx_memviewslice *, int); /*proto*/
+static Py_ssize_t __pyx_fill_contig_strides_array(Py_ssize_t *, Py_ssize_t *, Py_ssize_t, int, char); /*proto*/
+static void *__pyx_memoryview_copy_data_to_temp(__Pyx_memviewslice *, __Pyx_memviewslice *, char, int); /*proto*/
+static int __pyx_memoryview_err_extents(int, Py_ssize_t, Py_ssize_t); /*proto*/
+static int __pyx_memoryview_err_dim(PyObject *, char *, int); /*proto*/
+static int __pyx_memoryview_err(PyObject *, char *); /*proto*/
+static int __pyx_memoryview_copy_contents(__Pyx_memviewslice, __Pyx_memviewslice, int, int, int); /*proto*/
+static void __pyx_memoryview_broadcast_leading(__Pyx_memviewslice *, int, int); /*proto*/
+static void __pyx_memoryview_refcount_copying(__Pyx_memviewslice *, int, int, int); /*proto*/
+static void __pyx_memoryview_refcount_objects_in_slice_with_gil(char *, Py_ssize_t *, Py_ssize_t *, int, int); /*proto*/
+static void __pyx_memoryview_refcount_objects_in_slice(char *, Py_ssize_t *, Py_ssize_t *, int, int); /*proto*/
+static void __pyx_memoryview_slice_assign_scalar(__Pyx_memviewslice *, int, size_t, void *, int); /*proto*/
+static void __pyx_memoryview__slice_assign_scalar(char *, Py_ssize_t *, Py_ssize_t *, int, size_t, void *); /*proto*/
+static PyObject *__pyx_unpickle_Enum__set_state(struct __pyx_MemviewEnum_obj *, PyObject *); /*proto*/
+static __Pyx_TypeInfo __Pyx_TypeInfo_double = { "double", NULL, sizeof(double), { 0 }, 0, 'R', 0, 0 };
+static __Pyx_TypeInfo __Pyx_TypeInfo_int = { "int", NULL, sizeof(int), { 0 }, 0, IS_UNSIGNED(int) ? 'U' : 'I', IS_UNSIGNED(int), 0 };
+#define __Pyx_MODULE_NAME "triangle_hash"
+extern int __pyx_module_is_main_triangle_hash;
+int __pyx_module_is_main_triangle_hash = 0;
+
+/* Implementation of 'triangle_hash' */
+static PyObject *__pyx_builtin_range;
+static PyObject *__pyx_builtin_TypeError;
+static PyObject *__pyx_builtin_ImportError;
+static PyObject *__pyx_builtin_ValueError;
+static PyObject *__pyx_builtin_MemoryError;
+static PyObject *__pyx_builtin_enumerate;
+static PyObject *__pyx_builtin_Ellipsis;
+static PyObject *__pyx_builtin_id;
+static PyObject *__pyx_builtin_IndexError;
+static const char __pyx_k_O[] = "O";
+static const char __pyx_k_c[] = "c";
+static const char __pyx_k_id[] = "id";
+static const char __pyx_k_np[] = "np";
+static const char __pyx_k_new[] = "__new__";
+static const char __pyx_k_obj[] = "obj";
+static const char __pyx_k_base[] = "base";
+static const char __pyx_k_dict[] = "__dict__";
+static const char __pyx_k_main[] = "__main__";
+static const char __pyx_k_mode[] = "mode";
+static const char __pyx_k_name[] = "name";
+static const char __pyx_k_ndim[] = "ndim";
+static const char __pyx_k_pack[] = "pack";
+static const char __pyx_k_size[] = "size";
+static const char __pyx_k_step[] = "step";
+static const char __pyx_k_stop[] = "stop";
+static const char __pyx_k_test[] = "__test__";
+static const char __pyx_k_ASCII[] = "ASCII";
+static const char __pyx_k_class[] = "__class__";
+static const char __pyx_k_dtype[] = "dtype";
+static const char __pyx_k_error[] = "error";
+static const char __pyx_k_flags[] = "flags";
+static const char __pyx_k_int32[] = "int32";
+static const char __pyx_k_numpy[] = "numpy";
+static const char __pyx_k_query[] = "query";
+static const char __pyx_k_range[] = "range";
+static const char __pyx_k_shape[] = "shape";
+static const char __pyx_k_start[] = "start";
+static const char __pyx_k_zeros[] = "zeros";
+static const char __pyx_k_encode[] = "encode";
+static const char __pyx_k_format[] = "format";
+static const char __pyx_k_import[] = "__import__";
+static const char __pyx_k_name_2[] = "__name__";
+static const char __pyx_k_pickle[] = "pickle";
+static const char __pyx_k_reduce[] = "__reduce__";
+static const char __pyx_k_struct[] = "struct";
+static const char __pyx_k_unpack[] = "unpack";
+static const char __pyx_k_update[] = "update";
+static const char __pyx_k_fortran[] = "fortran";
+static const char __pyx_k_memview[] = "memview";
+static const char __pyx_k_Ellipsis[] = "Ellipsis";
+static const char __pyx_k_getstate[] = "__getstate__";
+static const char __pyx_k_itemsize[] = "itemsize";
+static const char __pyx_k_pyx_type[] = "__pyx_type";
+static const char __pyx_k_setstate[] = "__setstate__";
+static const char __pyx_k_TypeError[] = "TypeError";
+static const char __pyx_k_enumerate[] = "enumerate";
+static const char __pyx_k_pyx_state[] = "__pyx_state";
+static const char __pyx_k_reduce_ex[] = "__reduce_ex__";
+static const char __pyx_k_triangles[] = "triangles";
+static const char __pyx_k_IndexError[] = "IndexError";
+static const char __pyx_k_ValueError[] = "ValueError";
+static const char __pyx_k_pyx_result[] = "__pyx_result";
+static const char __pyx_k_pyx_vtable[] = "__pyx_vtable__";
+static const char __pyx_k_resolution[] = "resolution";
+static const char __pyx_k_ImportError[] = "ImportError";
+static const char __pyx_k_MemoryError[] = "MemoryError";
+static const char __pyx_k_PickleError[] = "PickleError";
+static const char __pyx_k_TriangleHash[] = "TriangleHash";
+static const char __pyx_k_pyx_checksum[] = "__pyx_checksum";
+static const char __pyx_k_stringsource[] = "stringsource";
+static const char __pyx_k_pyx_getbuffer[] = "__pyx_getbuffer";
+static const char __pyx_k_reduce_cython[] = "__reduce_cython__";
+static const char __pyx_k_View_MemoryView[] = "View.MemoryView";
+static const char __pyx_k_allocate_buffer[] = "allocate_buffer";
+static const char __pyx_k_dtype_is_object[] = "dtype_is_object";
+static const char __pyx_k_pyx_PickleError[] = "__pyx_PickleError";
+static const char __pyx_k_setstate_cython[] = "__setstate_cython__";
+static const char __pyx_k_pyx_unpickle_Enum[] = "__pyx_unpickle_Enum";
+static const char __pyx_k_cline_in_traceback[] = "cline_in_traceback";
+static const char __pyx_k_strided_and_direct[] = "";
+static const char __pyx_k_strided_and_indirect[] = "";
+static const char __pyx_k_contiguous_and_direct[] = "";
+static const char __pyx_k_MemoryView_of_r_object[] = "";
+static const char __pyx_k_MemoryView_of_r_at_0x_x[] = "";
+static const char __pyx_k_contiguous_and_indirect[] = "";
+static const char __pyx_k_Cannot_index_with_type_s[] = "Cannot index with type '%s'";
+static const char __pyx_k_Invalid_shape_in_axis_d_d[] = "Invalid shape in axis %d: %d.";
+static const char __pyx_k_itemsize_0_for_cython_array[] = "itemsize <= 0 for cython.array";
+static const char __pyx_k_unable_to_allocate_array_data[] = "unable to allocate array data.";
+static const char __pyx_k_strided_and_direct_or_indirect[] = "";
+static const char __pyx_k_numpy_core_multiarray_failed_to[] = "numpy.core.multiarray failed to import";
+static const char __pyx_k_Buffer_view_does_not_expose_stri[] = "Buffer view does not expose strides";
+static const char __pyx_k_Can_only_create_a_buffer_that_is[] = "Can only create a buffer that is contiguous in memory.";
+static const char __pyx_k_Cannot_assign_to_read_only_memor[] = "Cannot assign to read-only memoryview";
+static const char __pyx_k_Cannot_create_writable_memory_vi[] = "Cannot create writable memory view from read-only memoryview";
+static const char __pyx_k_Empty_shape_tuple_for_cython_arr[] = "Empty shape tuple for cython.array";
+static const char __pyx_k_Incompatible_checksums_0x_x_vs_0[] = "Incompatible checksums (0x%x vs (0xb068931, 0x82a3537, 0x6ae9995) = (name))";
+static const char __pyx_k_Indirect_dimensions_not_supporte[] = "Indirect dimensions not supported";
+static const char __pyx_k_Invalid_mode_expected_c_or_fortr[] = "Invalid mode, expected 'c' or 'fortran', got %s";
+static const char __pyx_k_Out_of_bounds_on_buffer_access_a[] = "Out of bounds on buffer access (axis %d)";
+static const char __pyx_k_Unable_to_convert_item_to_object[] = "Unable to convert item to object";
+static const char __pyx_k_got_differing_extents_in_dimensi[] = "got differing extents in dimension %d (got %d and %d)";
+static const char __pyx_k_no_default___reduce___due_to_non[] = "no default __reduce__ due to non-trivial __cinit__";
+static const char __pyx_k_numpy_core_umath_failed_to_impor[] = "numpy.core.umath failed to import";
+static const char __pyx_k_unable_to_allocate_shape_and_str[] = "unable to allocate shape and strides.";
+/* Module-level PyObject slots whose names mirror the __pyx_k_* literals.
+ * NOTE(review): presumably each holds the Python string object built from the
+ * matching literal during module initialisation; the prefixes look like
+ * Cython's convention (n = identifier-like name, kp = arbitrary text,
+ * s / u / b = str / unicode / bytes) -- confirm against the string table,
+ * which is outside this section. */
+static PyObject *__pyx_n_s_ASCII;
+static PyObject *__pyx_kp_s_Buffer_view_does_not_expose_stri;
+static PyObject *__pyx_kp_s_Can_only_create_a_buffer_that_is;
+static PyObject *__pyx_kp_s_Cannot_assign_to_read_only_memor;
+static PyObject *__pyx_kp_s_Cannot_create_writable_memory_vi;
+static PyObject *__pyx_kp_s_Cannot_index_with_type_s;
+static PyObject *__pyx_n_s_Ellipsis;
+static PyObject *__pyx_kp_s_Empty_shape_tuple_for_cython_arr;
+static PyObject *__pyx_n_s_ImportError;
+static PyObject *__pyx_kp_s_Incompatible_checksums_0x_x_vs_0;
+static PyObject *__pyx_n_s_IndexError;
+static PyObject *__pyx_kp_s_Indirect_dimensions_not_supporte;
+static PyObject *__pyx_kp_s_Invalid_mode_expected_c_or_fortr;
+static PyObject *__pyx_kp_s_Invalid_shape_in_axis_d_d;
+static PyObject *__pyx_n_s_MemoryError;
+static PyObject *__pyx_kp_s_MemoryView_of_r_at_0x_x;
+static PyObject *__pyx_kp_s_MemoryView_of_r_object;
+static PyObject *__pyx_n_b_O;
+static PyObject *__pyx_kp_s_Out_of_bounds_on_buffer_access_a;
+static PyObject *__pyx_n_s_PickleError;
+static PyObject *__pyx_n_s_TriangleHash;
+static PyObject *__pyx_n_s_TypeError;
+static PyObject *__pyx_kp_s_Unable_to_convert_item_to_object;
+static PyObject *__pyx_n_s_ValueError;
+static PyObject *__pyx_n_s_View_MemoryView;
+static PyObject *__pyx_n_s_allocate_buffer;
+static PyObject *__pyx_n_s_base;
+static PyObject *__pyx_n_s_c;
+static PyObject *__pyx_n_u_c;
+static PyObject *__pyx_n_s_class;
+static PyObject *__pyx_n_s_cline_in_traceback;
+static PyObject *__pyx_kp_s_contiguous_and_direct;
+static PyObject *__pyx_kp_s_contiguous_and_indirect;
+static PyObject *__pyx_n_s_dict;
+static PyObject *__pyx_n_s_dtype;
+static PyObject *__pyx_n_s_dtype_is_object;
+static PyObject *__pyx_n_s_encode;
+static PyObject *__pyx_n_s_enumerate;
+static PyObject *__pyx_n_s_error;
+static PyObject *__pyx_n_s_flags;
+static PyObject *__pyx_n_s_format;
+static PyObject *__pyx_n_s_fortran;
+static PyObject *__pyx_n_u_fortran;
+static PyObject *__pyx_n_s_getstate;
+static PyObject *__pyx_kp_s_got_differing_extents_in_dimensi;
+static PyObject *__pyx_n_s_id;
+static PyObject *__pyx_n_s_import;
+static PyObject *__pyx_n_s_int32;
+static PyObject *__pyx_n_s_itemsize;
+static PyObject *__pyx_kp_s_itemsize_0_for_cython_array;
+static PyObject *__pyx_n_s_main;
+static PyObject *__pyx_n_s_memview;
+static PyObject *__pyx_n_s_mode;
+static PyObject *__pyx_n_s_name;
+static PyObject *__pyx_n_s_name_2;
+static PyObject *__pyx_n_s_ndim;
+static PyObject *__pyx_n_s_new;
+static PyObject *__pyx_kp_s_no_default___reduce___due_to_non;
+static PyObject *__pyx_n_s_np;
+static PyObject *__pyx_n_s_numpy;
+static PyObject *__pyx_kp_s_numpy_core_multiarray_failed_to;
+static PyObject *__pyx_kp_s_numpy_core_umath_failed_to_impor;
+static PyObject *__pyx_n_s_obj;
+static PyObject *__pyx_n_s_pack;
+static PyObject *__pyx_n_s_pickle;
+static PyObject *__pyx_n_s_pyx_PickleError;
+static PyObject *__pyx_n_s_pyx_checksum;
+static PyObject *__pyx_n_s_pyx_getbuffer;
+static PyObject *__pyx_n_s_pyx_result;
+static PyObject *__pyx_n_s_pyx_state;
+static PyObject *__pyx_n_s_pyx_type;
+static PyObject *__pyx_n_s_pyx_unpickle_Enum;
+static PyObject *__pyx_n_s_pyx_vtable;
+static PyObject *__pyx_n_s_query;
+static PyObject *__pyx_n_s_range;
+static PyObject *__pyx_n_s_reduce;
+static PyObject *__pyx_n_s_reduce_cython;
+static PyObject *__pyx_n_s_reduce_ex;
+static PyObject *__pyx_n_s_resolution;
+static PyObject *__pyx_n_s_setstate;
+static PyObject *__pyx_n_s_setstate_cython;
+static PyObject *__pyx_n_s_shape;
+static PyObject *__pyx_n_s_size;
+static PyObject *__pyx_n_s_start;
+static PyObject *__pyx_n_s_step;
+static PyObject *__pyx_n_s_stop;
+static PyObject *__pyx_kp_s_strided_and_direct;
+static PyObject *__pyx_kp_s_strided_and_direct_or_indirect;
+static PyObject *__pyx_kp_s_strided_and_indirect;
+static PyObject *__pyx_kp_s_stringsource;
+static PyObject *__pyx_n_s_struct;
+static PyObject *__pyx_n_s_test;
+static PyObject *__pyx_n_s_triangles;
+static PyObject *__pyx_kp_s_unable_to_allocate_array_data;
+static PyObject *__pyx_kp_s_unable_to_allocate_shape_and_str;
+static PyObject *__pyx_n_s_unpack;
+static PyObject *__pyx_n_s_update;
+static PyObject *__pyx_n_s_zeros;
+/* Forward declarations of the implementation functions (__pyx_pf_*) and the
+ * type allocators (__pyx_tp_new_*) so that later definitions can reference
+ * them regardless of order. */
+/* triangle_hash.TriangleHash */
+static int __pyx_pf_13triangle_hash_12TriangleHash___cinit__(struct __pyx_obj_13triangle_hash_TriangleHash *__pyx_v_self, __Pyx_memviewslice __pyx_v_triangles, int __pyx_v_resolution); /* proto */
+static PyObject *__pyx_pf_13triangle_hash_12TriangleHash_2query(struct __pyx_obj_13triangle_hash_TriangleHash *__pyx_v_self, __Pyx_memviewslice __pyx_v_points); /* proto */
+static PyObject *__pyx_pf_13triangle_hash_12TriangleHash_4__reduce_cython__(CYTHON_UNUSED struct __pyx_obj_13triangle_hash_TriangleHash *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf_13triangle_hash_12TriangleHash_6__setstate_cython__(CYTHON_UNUSED struct __pyx_obj_13triangle_hash_TriangleHash *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */
+/* View.MemoryView.array */
+static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array___cinit__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_shape, Py_ssize_t __pyx_v_itemsize, PyObject *__pyx_v_format, PyObject *__pyx_v_mode, int __pyx_v_allocate_buffer); /* proto */
+static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_2__getbuffer__(struct __pyx_array_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /* proto */
+static void __pyx_array___pyx_pf_15View_dot_MemoryView_5array_4__dealloc__(struct __pyx_array_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf_15View_dot_MemoryView_5array_7memview___get__(struct __pyx_array_obj *__pyx_v_self); /* proto */
+static Py_ssize_t __pyx_array___pyx_pf_15View_dot_MemoryView_5array_6__len__(struct __pyx_array_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_8__getattr__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_attr); /* proto */
+static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_10__getitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item); /* proto */
+static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_12__setitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item, PyObject *__pyx_v_value); /* proto */
+static PyObject *__pyx_pf___pyx_array___reduce_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf___pyx_array_2__setstate_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */
+/* View.MemoryView.Enum */
+static int __pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum___init__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v_name); /* proto */
+static PyObject *__pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum_2__repr__(struct __pyx_MemviewEnum_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf___pyx_MemviewEnum___reduce_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf___pyx_MemviewEnum_2__setstate_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v___pyx_state); /* proto */
+/* View.MemoryView.memoryview */
+static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview___cinit__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj, int __pyx_v_flags, int __pyx_v_dtype_is_object); /* proto */
+static void __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_2__dealloc__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_4__getitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index); /* proto */
+static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_6__setitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value); /* proto */
+static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_8__getbuffer__(struct __pyx_memoryview_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /* proto */
+static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_1T___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4base___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_5shape___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_7strides___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_10suboffsets___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4ndim___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_8itemsize___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_6nbytes___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4size___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static Py_ssize_t __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_10__len__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_12__repr__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_14__str__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_16is_c_contig(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_18is_f_contig(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_20copy(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_22copy_fortran(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf___pyx_memoryview___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf___pyx_memoryview_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */
+/* View.MemoryView._memoryviewslice */
+static void __pyx_memoryviewslice___pyx_pf_15View_dot_MemoryView_16_memoryviewslice___dealloc__(struct __pyx_memoryviewslice_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf_15View_dot_MemoryView_16_memoryviewslice_4base___get__(struct __pyx_memoryviewslice_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf___pyx_memoryviewslice___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self); /* proto */
+static PyObject *__pyx_pf___pyx_memoryviewslice_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */
+/* Module-level unpickle helper for View.MemoryView.Enum */
+static PyObject *__pyx_pf_15View_dot_MemoryView___pyx_unpickle_Enum(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v___pyx_type, long __pyx_v___pyx_checksum, PyObject *__pyx_v___pyx_state); /* proto */
+/* tp_new allocators, one per extension type */
+static PyObject *__pyx_tp_new_13triangle_hash_TriangleHash(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/
+static PyObject *__pyx_tp_new_array(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/
+static PyObject *__pyx_tp_new_Enum(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/
+static PyObject *__pyx_tp_new_memoryview(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/
+static PyObject *__pyx_tp_new__memoryviewslice(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/
+/* Module-level slots for cached constant objects (ints, tuples, one slice and
+ * one code object). NOTE(review): presumably created once at module
+ * initialisation; that code is outside this section -- confirm.
+ * The three long integer names are the decimal forms of 0x6ae9995, 0x82a3537
+ * and 0xb068931, the checksums quoted in the "Incompatible checksums" message. */
+static PyObject *__pyx_int_0;
+static PyObject *__pyx_int_1;
+static PyObject *__pyx_int_112105877;
+static PyObject *__pyx_int_136983863;
+static PyObject *__pyx_int_184977713;
+static PyObject *__pyx_int_neg_1;
+static PyObject *__pyx_tuple_;
+static PyObject *__pyx_tuple__2;
+static PyObject *__pyx_tuple__3;
+static PyObject *__pyx_tuple__4;
+static PyObject *__pyx_tuple__5;
+static PyObject *__pyx_tuple__6;
+static PyObject *__pyx_tuple__7;
+static PyObject *__pyx_tuple__8;
+static PyObject *__pyx_tuple__9;
+static PyObject *__pyx_slice__19;
+static PyObject *__pyx_tuple__10;
+static PyObject *__pyx_tuple__11;
+static PyObject *__pyx_tuple__12;
+static PyObject *__pyx_tuple__13;
+static PyObject *__pyx_tuple__14;
+static PyObject *__pyx_tuple__15;
+static PyObject *__pyx_tuple__16;
+static PyObject *__pyx_tuple__17;
+static PyObject *__pyx_tuple__18;
+static PyObject *__pyx_tuple__20;
+static PyObject *__pyx_tuple__21;
+static PyObject *__pyx_tuple__22;
+static PyObject *__pyx_tuple__23;
+static PyObject *__pyx_tuple__24;
+static PyObject *__pyx_tuple__25;
+static PyObject *__pyx_tuple__26;
+static PyObject *__pyx_tuple__27;
+static PyObject *__pyx_tuple__28;
+static PyObject *__pyx_tuple__29;
+static PyObject *__pyx_codeobj__30;
+/* Late includes */
+
+/* "triangle_hash.pyx":13
+ * cdef int resolution
+ *
+ * def __cinit__(self, double[:, :, :] triangles, int resolution): # <<<<<<<<<<<<<<
+ * self.spatial_hash.resize(resolution * resolution)
+ * self.resolution = resolution
+ */
+
+/* Python wrapper for TriangleHash.__cinit__(triangles, resolution).
+ *
+ * Unpacks exactly two arguments, given positionally and/or by keyword, converts
+ * them to C values (a double memoryview slice declared as double[:, :, :] in
+ * the .pyx, and an int) and forwards them to the implementation function.
+ *
+ * Returns the implementation's result, or -1 with a Python exception set and a
+ * traceback entry added when argument unpacking or conversion fails.
+ */
+static int __pyx_pw_13triangle_hash_12TriangleHash_1__cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
+static int __pyx_pw_13triangle_hash_12TriangleHash_1__cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
+ __Pyx_memviewslice __pyx_v_triangles = { 0, 0, { 0 }, { 0 }, { 0 } };
+ int __pyx_v_resolution;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ int __pyx_r;
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("__cinit__ (wrapper)", 0);
+ {
+ /* Keyword names in positional order; values[] receives the matched objects. */
+ static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_triangles,&__pyx_n_s_resolution,0};
+ PyObject* values[2] = {0,0};
+ if (unlikely(__pyx_kwds)) {
+ Py_ssize_t kw_args;
+ const Py_ssize_t pos_args = PyTuple_GET_SIZE(__pyx_args);
+ /* First take whatever was passed positionally (0, 1 or 2 items). */
+ switch (pos_args) {
+ case 2: values[1] = PyTuple_GET_ITEM(__pyx_args, 1);
+ CYTHON_FALLTHROUGH;
+ case 1: values[0] = PyTuple_GET_ITEM(__pyx_args, 0);
+ CYTHON_FALLTHROUGH;
+ case 0: break;
+ default: goto __pyx_L5_argtuple_error;
+ }
+ kw_args = PyDict_Size(__pyx_kwds);
+ /* Then fill the remaining slots from the keyword dict; both are required. */
+ switch (pos_args) {
+ case 0:
+ if (likely((values[0] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_triangles)) != 0)) kw_args--;
+ else goto __pyx_L5_argtuple_error;
+ CYTHON_FALLTHROUGH;
+ case 1:
+ if (likely((values[1] = __Pyx_PyDict_GetItemStr(__pyx_kwds, __pyx_n_s_resolution)) != 0)) kw_args--;
+ else {
+ __Pyx_RaiseArgtupleInvalid("__cinit__", 1, 2, 2, 1); __PYX_ERR(0, 13, __pyx_L3_error)
+ }
+ }
+ /* Keywords left over after the two lookups go through the generic keyword parser. */
+ if (unlikely(kw_args > 0)) {
+ if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "__cinit__") < 0)) __PYX_ERR(0, 13, __pyx_L3_error)
+ }
+ } else if (PyTuple_GET_SIZE(__pyx_args) != 2) {
+ goto __pyx_L5_argtuple_error;
+ } else {
+ values[0] = PyTuple_GET_ITEM(__pyx_args, 0);
+ values[1] = PyTuple_GET_ITEM(__pyx_args, 1);
+ }
+ /* Convert to C types; the buffer is requested with PyBUF_WRITABLE. */
+ __pyx_v_triangles = __Pyx_PyObject_to_MemoryviewSlice_dsdsds_double(values[0], PyBUF_WRITABLE); if (unlikely(!__pyx_v_triangles.memview)) __PYX_ERR(0, 13, __pyx_L3_error)
+ __pyx_v_resolution = __Pyx_PyInt_As_int(values[1]); if (unlikely((__pyx_v_resolution == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 13, __pyx_L3_error)
+ }
+ goto __pyx_L4_argument_unpacking_done;
+ /* Wrong number of arguments. */
+ __pyx_L5_argtuple_error:;
+ __Pyx_RaiseArgtupleInvalid("__cinit__", 1, 2, 2, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 13, __pyx_L3_error)
+ /* Common failure exit: record the traceback entry and report -1. */
+ __pyx_L3_error:;
+ __Pyx_AddTraceback("triangle_hash.TriangleHash.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __Pyx_RefNannyFinishContext();
+ return -1;
+ __pyx_L4_argument_unpacking_done:;
+ __pyx_r = __pyx_pf_13triangle_hash_12TriangleHash___cinit__(((struct __pyx_obj_13triangle_hash_TriangleHash *)__pyx_v_self), __pyx_v_triangles, __pyx_v_resolution);
+
+ /* function exit code */
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* Implementation of TriangleHash.__cinit__ (triangle_hash.pyx:13-16).
+ *
+ * Sizes the spatial hash to resolution * resolution cells, stores the
+ * resolution on the instance and populates the table via _build_hash.
+ *
+ * Parameters:
+ *   __pyx_v_self        instance being initialised.
+ *   __pyx_v_triangles   memoryview slice of triangle coordinates; the slice
+ *                       reference is released before returning, on both the
+ *                       success and the error path.
+ *   __pyx_v_resolution  grid cells per axis; must be positive and small
+ *                       enough that resolution * resolution fits in an int.
+ *
+ * Returns 0 on success, or -1 with a Python exception set: ValueError for a
+ * non-positive resolution, OverflowError for one that is too large, or the
+ * translated C++ exception if resizing the table throws.
+ *
+ * NOTE(review): hand-edited generated code -- apply the same guards to
+ * triangle_hash.pyx so they survive regeneration.
+ */
+static int __pyx_pf_13triangle_hash_12TriangleHash___cinit__(struct __pyx_obj_13triangle_hash_TriangleHash *__pyx_v_self, __Pyx_memviewslice __pyx_v_triangles, int __pyx_v_resolution) {
+  int __pyx_r;
+  __Pyx_RefNannyDeclarations
+  int __pyx_lineno = 0;
+  const char *__pyx_filename = NULL;
+  int __pyx_clineno = 0;
+  __Pyx_RefNannySetupContext("__cinit__", 0);
+
+  /* A non-positive resolution would make _build_hash clamp cell coordinates
+   * to resolution - 1 < 0 and write before the start of the table. */
+  if (unlikely(__pyx_v_resolution <= 0)) {
+    PyErr_SetString(PyExc_ValueError, "resolution must be a positive integer");
+    __PYX_ERR(0, 14, __pyx_L1_error)
+  }
+
+  /* Cell indices are computed in int arithmetic (resolution * x + y), so the
+   * total cell count has to fit in an int as well. */
+  if (unlikely(__pyx_v_resolution > INT_MAX / __pyx_v_resolution)) {
+    PyErr_SetString(PyExc_OverflowError, "resolution is too large: resolution * resolution overflows a C int");
+    __PYX_ERR(0, 14, __pyx_L1_error)
+  }
+
+  /* "triangle_hash.pyx":14
+   *     self.spatial_hash.resize(resolution * resolution)
+   */
+  try {
+    __pyx_v_self->spatial_hash.resize((size_t)(__pyx_v_resolution * __pyx_v_resolution));
+  } catch(...) {
+    __Pyx_CppExn2PyErr();
+    __PYX_ERR(0, 14, __pyx_L1_error)
+  }
+
+  /* "triangle_hash.pyx":15
+   *     self.resolution = resolution
+   */
+  __pyx_v_self->resolution = __pyx_v_resolution;
+
+  /* "triangle_hash.pyx":16
+   *     self._build_hash(triangles)
+   *
+   * The return value is discarded, as in the generated original.
+   */
+  (void)(((struct __pyx_vtabstruct_13triangle_hash_TriangleHash *)__pyx_v_self->__pyx_vtab)->_build_hash(__pyx_v_self, __pyx_v_triangles));
+
+  /* function exit code */
+  __pyx_r = 0;
+  goto __pyx_L0;
+  __pyx_L1_error:;
+  __Pyx_AddTraceback("triangle_hash.TriangleHash.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename);
+  __pyx_r = -1;
+  __pyx_L0:;
+  __PYX_XDEC_MEMVIEW(&__pyx_v_triangles, 1);
+  __Pyx_RefNannyFinishContext();
+  return __pyx_r;
+}
+
+/* "triangle_hash.pyx":20
+ * @cython.boundscheck(False) # Deactivate bounds checking
+ * @cython.wraparound(False) # Deactivate negative indexing.
+ * cdef int _build_hash(self, double[:, :, :] triangles): # <<<<<<<<<<<<<<
+ * assert(triangles.shape[1] == 3)
+ * assert(triangles.shape[2] == 2)
+ */
+
+static int __pyx_f_13triangle_hash_12TriangleHash__build_hash(struct __pyx_obj_13triangle_hash_TriangleHash *__pyx_v_self, __Pyx_memviewslice __pyx_v_triangles) {
+ int __pyx_v_n_tri;
+ int __pyx_v_bbox_min[2];
+ int __pyx_v_bbox_max[2];
+ int __pyx_v_i_tri;
+ int __pyx_v_j;
+ int __pyx_v_x;
+ int __pyx_v_y;
+ int __pyx_v_spatial_idx;
+ int __pyx_r;
+ __Pyx_RefNannyDeclarations
+ int __pyx_t_1;
+ int __pyx_t_2;
+ int __pyx_t_3;
+ int __pyx_t_4;
+ Py_ssize_t __pyx_t_5;
+ Py_ssize_t __pyx_t_6;
+ Py_ssize_t __pyx_t_7;
+ double __pyx_t_8;
+ double __pyx_t_9;
+ double __pyx_t_10;
+ double __pyx_t_11;
+ long __pyx_t_12;
+ long __pyx_t_13;
+ int __pyx_t_14;
+ long __pyx_t_15;
+ long __pyx_t_16;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("_build_hash", 0);
+
+ /* "triangle_hash.pyx":21
+ * @cython.wraparound(False) # Deactivate negative indexing.
+ * cdef int _build_hash(self, double[:, :, :] triangles):
+ * assert(triangles.shape[1] == 3) # <<<<<<<<<<<<<<
+ * assert(triangles.shape[2] == 2)
+ *
+ */
+ #ifndef CYTHON_WITHOUT_ASSERTIONS
+ if (unlikely(!Py_OptimizeFlag)) {
+ if (unlikely(!(((__pyx_v_triangles.shape[1]) == 3) != 0))) {
+ PyErr_SetNone(PyExc_AssertionError);
+ __PYX_ERR(0, 21, __pyx_L1_error)
+ }
+ }
+ #endif
+
+ /* "triangle_hash.pyx":22
+ * cdef int _build_hash(self, double[:, :, :] triangles):
+ * assert(triangles.shape[1] == 3)
+ * assert(triangles.shape[2] == 2) # <<<<<<<<<<<<<<
+ *
+ * cdef int n_tri = triangles.shape[0]
+ */
+ #ifndef CYTHON_WITHOUT_ASSERTIONS
+ if (unlikely(!Py_OptimizeFlag)) {
+ if (unlikely(!(((__pyx_v_triangles.shape[2]) == 2) != 0))) {
+ PyErr_SetNone(PyExc_AssertionError);
+ __PYX_ERR(0, 22, __pyx_L1_error)
+ }
+ }
+ #endif
+
+ /* "triangle_hash.pyx":24
+ * assert(triangles.shape[2] == 2)
+ *
+ * cdef int n_tri = triangles.shape[0] # <<<<<<<<<<<<<<
+ * cdef int bbox_min[2]
+ * cdef int bbox_max[2]
+ */
+ __pyx_v_n_tri = (__pyx_v_triangles.shape[0]);
+
+ /* "triangle_hash.pyx":31
+ * cdef int spatial_idx
+ *
+ * for i_tri in range(n_tri): # <<<<<<<<<<<<<<
+ * # Compute bounding box
+ * for j in range(2):
+ */
+ __pyx_t_1 = __pyx_v_n_tri;
+ __pyx_t_2 = __pyx_t_1;
+ for (__pyx_t_3 = 0; __pyx_t_3 < __pyx_t_2; __pyx_t_3+=1) {
+ __pyx_v_i_tri = __pyx_t_3;
+
+ /* "triangle_hash.pyx":33
+ * for i_tri in range(n_tri):
+ * # Compute bounding box
+ * for j in range(2): # <<<<<<<<<<<<<<
+ * bbox_min[j] = min(
+ * triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]
+ */
+ for (__pyx_t_4 = 0; __pyx_t_4 < 2; __pyx_t_4+=1) {
+ __pyx_v_j = __pyx_t_4;
+
+ /* "triangle_hash.pyx":35
+ * for j in range(2):
+ * bbox_min[j] = min(
+ * triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j] # <<<<<<<<<<<<<<
+ * )
+ * bbox_max[j] = max(
+ */
+ __pyx_t_5 = __pyx_v_i_tri;
+ __pyx_t_6 = 1;
+ __pyx_t_7 = __pyx_v_j;
+ __pyx_t_8 = (*((double *) ( /* dim=2 */ (( /* dim=1 */ (( /* dim=0 */ (__pyx_v_triangles.data + __pyx_t_5 * __pyx_v_triangles.strides[0]) ) + __pyx_t_6 * __pyx_v_triangles.strides[1]) ) + __pyx_t_7 * __pyx_v_triangles.strides[2]) )));
+ __pyx_t_7 = __pyx_v_i_tri;
+ __pyx_t_6 = 2;
+ __pyx_t_5 = __pyx_v_j;
+ __pyx_t_9 = (*((double *) ( /* dim=2 */ (( /* dim=1 */ (( /* dim=0 */ (__pyx_v_triangles.data + __pyx_t_7 * __pyx_v_triangles.strides[0]) ) + __pyx_t_6 * __pyx_v_triangles.strides[1]) ) + __pyx_t_5 * __pyx_v_triangles.strides[2]) )));
+ __pyx_t_5 = __pyx_v_i_tri;
+ __pyx_t_6 = 0;
+ __pyx_t_7 = __pyx_v_j;
+ __pyx_t_10 = (*((double *) ( /* dim=2 */ (( /* dim=1 */ (( /* dim=0 */ (__pyx_v_triangles.data + __pyx_t_5 * __pyx_v_triangles.strides[0]) ) + __pyx_t_6 * __pyx_v_triangles.strides[1]) ) + __pyx_t_7 * __pyx_v_triangles.strides[2]) )));
+ if (((__pyx_t_8 < __pyx_t_10) != 0)) {
+ __pyx_t_11 = __pyx_t_8;
+ } else {
+ __pyx_t_11 = __pyx_t_10;
+ }
+ __pyx_t_10 = __pyx_t_11;
+ if (((__pyx_t_9 < __pyx_t_10) != 0)) {
+ __pyx_t_11 = __pyx_t_9;
+ } else {
+ __pyx_t_11 = __pyx_t_10;
+ }
+
+ /* "triangle_hash.pyx":34
+ * # Compute bounding box
+ * for j in range(2):
+ * bbox_min[j] = min( # <<<<<<<<<<<<<<
+ * triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]
+ * )
+ */
+ (__pyx_v_bbox_min[__pyx_v_j]) = ((int)__pyx_t_11);
+
+ /* "triangle_hash.pyx":38
+ * )
+ * bbox_max[j] = max(
+ * triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j] # <<<<<<<<<<<<<<
+ * )
+ * bbox_min[j] = min(max(bbox_min[j], 0), self.resolution - 1)
+ */
+ __pyx_t_7 = __pyx_v_i_tri;
+ __pyx_t_6 = 1;
+ __pyx_t_5 = __pyx_v_j;
+ __pyx_t_11 = (*((double *) ( /* dim=2 */ (( /* dim=1 */ (( /* dim=0 */ (__pyx_v_triangles.data + __pyx_t_7 * __pyx_v_triangles.strides[0]) ) + __pyx_t_6 * __pyx_v_triangles.strides[1]) ) + __pyx_t_5 * __pyx_v_triangles.strides[2]) )));
+ __pyx_t_5 = __pyx_v_i_tri;
+ __pyx_t_6 = 2;
+ __pyx_t_7 = __pyx_v_j;
+ __pyx_t_8 = (*((double *) ( /* dim=2 */ (( /* dim=1 */ (( /* dim=0 */ (__pyx_v_triangles.data + __pyx_t_5 * __pyx_v_triangles.strides[0]) ) + __pyx_t_6 * __pyx_v_triangles.strides[1]) ) + __pyx_t_7 * __pyx_v_triangles.strides[2]) )));
+ __pyx_t_7 = __pyx_v_i_tri;
+ __pyx_t_6 = 0;
+ __pyx_t_5 = __pyx_v_j;
+ __pyx_t_9 = (*((double *) ( /* dim=2 */ (( /* dim=1 */ (( /* dim=0 */ (__pyx_v_triangles.data + __pyx_t_7 * __pyx_v_triangles.strides[0]) ) + __pyx_t_6 * __pyx_v_triangles.strides[1]) ) + __pyx_t_5 * __pyx_v_triangles.strides[2]) )));
+ if (((__pyx_t_11 > __pyx_t_9) != 0)) {
+ __pyx_t_10 = __pyx_t_11;
+ } else {
+ __pyx_t_10 = __pyx_t_9;
+ }
+ __pyx_t_9 = __pyx_t_10;
+ if (((__pyx_t_8 > __pyx_t_9) != 0)) {
+ __pyx_t_10 = __pyx_t_8;
+ } else {
+ __pyx_t_10 = __pyx_t_9;
+ }
+
+ /* "triangle_hash.pyx":37
+ * triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]
+ * )
+ * bbox_max[j] = max( # <<<<<<<<<<<<<<
+ * triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]
+ * )
+ */
+ (__pyx_v_bbox_max[__pyx_v_j]) = ((int)__pyx_t_10);
+
+ /* "triangle_hash.pyx":40
+ * triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j]
+ * )
+ * bbox_min[j] = min(max(bbox_min[j], 0), self.resolution - 1) # <<<<<<<<<<<<<<
+ * bbox_max[j] = min(max(bbox_max[j], 0), self.resolution - 1)
+ *
+ */
+ __pyx_t_12 = (__pyx_v_self->resolution - 1);
+ __pyx_t_13 = 0;
+ __pyx_t_14 = (__pyx_v_bbox_min[__pyx_v_j]);
+ if (((__pyx_t_13 > __pyx_t_14) != 0)) {
+ __pyx_t_15 = __pyx_t_13;
+ } else {
+ __pyx_t_15 = __pyx_t_14;
+ }
+ __pyx_t_13 = __pyx_t_15;
+ if (((__pyx_t_12 < __pyx_t_13) != 0)) {
+ __pyx_t_15 = __pyx_t_12;
+ } else {
+ __pyx_t_15 = __pyx_t_13;
+ }
+ (__pyx_v_bbox_min[__pyx_v_j]) = __pyx_t_15;
+
+ /* "triangle_hash.pyx":41
+ * )
+ * bbox_min[j] = min(max(bbox_min[j], 0), self.resolution - 1)
+ * bbox_max[j] = min(max(bbox_max[j], 0), self.resolution - 1) # <<<<<<<<<<<<<<
+ *
+ * # Find all voxels where bounding box intersects
+ */
+ __pyx_t_15 = (__pyx_v_self->resolution - 1);
+ __pyx_t_12 = 0;
+ __pyx_t_14 = (__pyx_v_bbox_max[__pyx_v_j]);
+ if (((__pyx_t_12 > __pyx_t_14) != 0)) {
+ __pyx_t_13 = __pyx_t_12;
+ } else {
+ __pyx_t_13 = __pyx_t_14;
+ }
+ __pyx_t_12 = __pyx_t_13;
+ if (((__pyx_t_15 < __pyx_t_12) != 0)) {
+ __pyx_t_13 = __pyx_t_15;
+ } else {
+ __pyx_t_13 = __pyx_t_12;
+ }
+ (__pyx_v_bbox_max[__pyx_v_j]) = __pyx_t_13;
+ }
+
+ /* "triangle_hash.pyx":44
+ *
+ * # Find all voxels where bounding box intersects
+ * for x in range(bbox_min[0], bbox_max[0] + 1): # <<<<<<<<<<<<<<
+ * for y in range(bbox_min[1], bbox_max[1] + 1):
+ * spatial_idx = self.resolution * x + y
+ */
+ __pyx_t_13 = ((__pyx_v_bbox_max[0]) + 1);
+ __pyx_t_15 = __pyx_t_13;
+ for (__pyx_t_4 = (__pyx_v_bbox_min[0]); __pyx_t_4 < __pyx_t_15; __pyx_t_4+=1) {
+ __pyx_v_x = __pyx_t_4;
+
+ /* "triangle_hash.pyx":45
+ * # Find all voxels where bounding box intersects
+ * for x in range(bbox_min[0], bbox_max[0] + 1):
+ * for y in range(bbox_min[1], bbox_max[1] + 1): # <<<<<<<<<<<<<<
+ * spatial_idx = self.resolution * x + y
+ * self.spatial_hash[spatial_idx].push_back(i_tri)
+ */
+ __pyx_t_12 = ((__pyx_v_bbox_max[1]) + 1);
+ __pyx_t_16 = __pyx_t_12;
+ for (__pyx_t_14 = (__pyx_v_bbox_min[1]); __pyx_t_14 < __pyx_t_16; __pyx_t_14+=1) {
+ __pyx_v_y = __pyx_t_14;
+
+ /* "triangle_hash.pyx":46
+ * for x in range(bbox_min[0], bbox_max[0] + 1):
+ * for y in range(bbox_min[1], bbox_max[1] + 1):
+ * spatial_idx = self.resolution * x + y # <<<<<<<<<<<<<<
+ * self.spatial_hash[spatial_idx].push_back(i_tri)
+ *
+ */
+ __pyx_v_spatial_idx = ((__pyx_v_self->resolution * __pyx_v_x) + __pyx_v_y);
+
+ /* "triangle_hash.pyx":47
+ * for y in range(bbox_min[1], bbox_max[1] + 1):
+ * spatial_idx = self.resolution * x + y
+ * self.spatial_hash[spatial_idx].push_back(i_tri) # <<<<<<<<<<<<<<
+ *
+ * @cython.boundscheck(False) # Deactivate bounds checking
+ */
+ try {
+ (__pyx_v_self->spatial_hash[__pyx_v_spatial_idx]).push_back(__pyx_v_i_tri);
+ } catch(...) {
+ __Pyx_CppExn2PyErr();
+ __PYX_ERR(0, 47, __pyx_L1_error)
+ }
+ }
+ }
+ }
+
+ /* "triangle_hash.pyx":20
+ * @cython.boundscheck(False) # Deactivate bounds checking
+ * @cython.wraparound(False) # Deactivate negative indexing.
+ * cdef int _build_hash(self, double[:, :, :] triangles): # <<<<<<<<<<<<<<
+ * assert(triangles.shape[1] == 3)
+ * assert(triangles.shape[2] == 2)
+ */
+
+ /* function exit code */
+ __pyx_r = 0;
+ goto __pyx_L0;
+ __pyx_L1_error:;
+ __Pyx_WriteUnraisable("triangle_hash.TriangleHash._build_hash", __pyx_clineno, __pyx_lineno, __pyx_filename, 1, 0);
+ __pyx_r = 0;
+ __pyx_L0:;
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "triangle_hash.pyx":51
+ * @cython.boundscheck(False) # Deactivate bounds checking
+ * @cython.wraparound(False) # Deactivate negative indexing.
+ * cpdef query(self, double[:, :] points): # <<<<<<<<<<<<<<
+ * assert(points.shape[1] == 2)
+ * cdef int n_points = points.shape[0]
+ *
+ * C-level body of TriangleHash.query, generated from the .pyx lines quoted in
+ * the comments throughout this function.
+ *
+ * Dispatch: unless __pyx_skip_dispatch is set, and only when the instance's type
+ * has a nonzero tp_dictoffset or is an abstract/heap type, the attribute "query"
+ * is looked up on self; if it is not this module's own wrapper, the call is
+ * forwarded to that attribute and its result is returned.
+ *
+ * Otherwise, for every row of `points` the two coordinates are cast to int, rows
+ * whose x or y falls outside [0, self->resolution) are skipped, and one
+ * (point index, triangle index) pair is recorded for each entry stored in
+ * spatial_hash[resolution * x + y].
+ *
+ * Returns a 2-tuple (points_indices_np, tri_indices_np) of arrays created with
+ * np.zeros(..., dtype=np.int32) and filled from the collected pairs, or 0 with a
+ * traceback entry added when any step fails.
+ */
+
+static PyObject *__pyx_pw_13triangle_hash_12TriangleHash_3query(PyObject *__pyx_v_self, PyObject *__pyx_arg_points); /*proto*/
+static PyObject *__pyx_f_13triangle_hash_12TriangleHash_query(struct __pyx_obj_13triangle_hash_TriangleHash *__pyx_v_self, __Pyx_memviewslice __pyx_v_points, int __pyx_skip_dispatch) {
+ int __pyx_v_n_points;
+ std::vector __pyx_v_points_indices;
+ std::vector __pyx_v_tri_indices;
+ int __pyx_v_i_point;
+ int __pyx_v_k;
+ int __pyx_v_x;
+ int __pyx_v_y;
+ int __pyx_v_spatial_idx;
+ int __pyx_v_i_tri;
+ PyObject *__pyx_v_points_indices_np = NULL;
+ PyObject *__pyx_v_tri_indices_np = NULL;
+ __Pyx_memviewslice __pyx_v_points_indices_view = { 0, 0, { 0 }, { 0 }, { 0 } };
+ __Pyx_memviewslice __pyx_v_tri_indices_view = { 0, 0, { 0 }, { 0 }, { 0 } };
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ PyObject *__pyx_t_2 = NULL;
+ PyObject *__pyx_t_3 = NULL;
+ PyObject *__pyx_t_4 = NULL;
+ PyObject *__pyx_t_5 = NULL;
+ int __pyx_t_6;
+ int __pyx_t_7;
+ int __pyx_t_8;
+ Py_ssize_t __pyx_t_9;
+ Py_ssize_t __pyx_t_10;
+ int __pyx_t_11;
+ int __pyx_t_12;
+ int __pyx_t_13;
+ std::vector ::iterator __pyx_t_14;
+ std::vector *__pyx_t_15;
+ int __pyx_t_16;
+ __Pyx_memviewslice __pyx_t_17 = { 0, 0, { 0 }, { 0 }, { 0 } };
+ std::vector ::size_type __pyx_t_18;
+ std::vector ::size_type __pyx_t_19;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("query", 0);
+ /* Check if called by wrapper */
+ if (unlikely(__pyx_skip_dispatch)) ;
+ /* Check if overridden in Python */
+ else if (unlikely((Py_TYPE(((PyObject *)__pyx_v_self))->tp_dictoffset != 0) || (Py_TYPE(((PyObject *)__pyx_v_self))->tp_flags & (Py_TPFLAGS_IS_ABSTRACT | Py_TPFLAGS_HEAPTYPE)))) {
+ #if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
+ static PY_UINT64_T __pyx_tp_dict_version = __PYX_DICT_VERSION_INIT, __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
+ if (unlikely(!__Pyx_object_dict_version_matches(((PyObject *)__pyx_v_self), __pyx_tp_dict_version, __pyx_obj_dict_version))) {
+ PY_UINT64_T __pyx_type_dict_guard = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
+ #endif
+ __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_query); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 51, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ if (!PyCFunction_Check(__pyx_t_1) || (PyCFunction_GET_FUNCTION(__pyx_t_1) != (PyCFunction)(void*)__pyx_pw_13triangle_hash_12TriangleHash_3query)) { /* "query" resolves to something other than this module's wrapper: call it instead */
+ __Pyx_XDECREF(__pyx_r);
+ if (unlikely(!__pyx_v_points.memview)) { __Pyx_RaiseUnboundLocalError("points"); __PYX_ERR(0, 51, __pyx_L1_error) }
+ __pyx_t_3 = __pyx_memoryview_fromslice(__pyx_v_points, 2, (PyObject *(*)(char *)) __pyx_memview_get_double, (int (*)(char *, PyObject *)) __pyx_memview_set_double, 0);; if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 51, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_INCREF(__pyx_t_1);
+ __pyx_t_4 = __pyx_t_1; __pyx_t_5 = NULL;
+ if (CYTHON_UNPACK_METHODS && unlikely(PyMethod_Check(__pyx_t_4))) {
+ __pyx_t_5 = PyMethod_GET_SELF(__pyx_t_4);
+ if (likely(__pyx_t_5)) {
+ PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_4);
+ __Pyx_INCREF(__pyx_t_5);
+ __Pyx_INCREF(function);
+ __Pyx_DECREF_SET(__pyx_t_4, function);
+ }
+ }
+ __pyx_t_2 = (__pyx_t_5) ? __Pyx_PyObject_Call2Args(__pyx_t_4, __pyx_t_5, __pyx_t_3) : __Pyx_PyObject_CallOneArg(__pyx_t_4, __pyx_t_3);
+ __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0;
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 51, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_2);
+ __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
+ __pyx_r = __pyx_t_2;
+ __pyx_t_2 = 0;
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ goto __pyx_L0;
+ }
+ #if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
+ __pyx_tp_dict_version = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
+ __pyx_obj_dict_version = __Pyx_get_object_dict_version(((PyObject *)__pyx_v_self));
+ if (unlikely(__pyx_type_dict_guard != __pyx_tp_dict_version)) {
+ __pyx_tp_dict_version = __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
+ }
+ #endif
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ #if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
+ }
+ #endif
+ }
+
+ /* "triangle_hash.pyx":52
+ * @cython.wraparound(False) # Deactivate negative indexing.
+ * cpdef query(self, double[:, :] points):
+ * assert(points.shape[1] == 2) # <<<<<<<<<<<<<<
+ * cdef int n_points = points.shape[0]
+ *
+ * The check below is compiled out under CYTHON_WITHOUT_ASSERTIONS and skipped
+ * at run time when Py_OptimizeFlag is set.
+ */
+ #ifndef CYTHON_WITHOUT_ASSERTIONS
+ if (unlikely(!Py_OptimizeFlag)) {
+ if (unlikely(!(((__pyx_v_points.shape[1]) == 2) != 0))) {
+ PyErr_SetNone(PyExc_AssertionError);
+ __PYX_ERR(0, 52, __pyx_L1_error)
+ }
+ }
+ #endif
+
+ /* "triangle_hash.pyx":53
+ * cpdef query(self, double[:, :] points):
+ * assert(points.shape[1] == 2)
+ * cdef int n_points = points.shape[0] # <<<<<<<<<<<<<<
+ *
+ * cdef vector[int] points_indices
+ */
+ __pyx_v_n_points = (__pyx_v_points.shape[0]);
+
+ /* "triangle_hash.pyx":63
+ * cdef int spatial_idx
+ *
+ * for i_point in range(n_points): # <<<<<<<<<<<<<<
+ * x = int(points[i_point, 0])
+ * y = int(points[i_point, 1])
+ */
+ __pyx_t_6 = __pyx_v_n_points;
+ __pyx_t_7 = __pyx_t_6;
+ for (__pyx_t_8 = 0; __pyx_t_8 < __pyx_t_7; __pyx_t_8+=1) {
+ __pyx_v_i_point = __pyx_t_8;
+
+ /* "triangle_hash.pyx":64
+ *
+ * for i_point in range(n_points):
+ * x = int(points[i_point, 0]) # <<<<<<<<<<<<<<
+ * y = int(points[i_point, 1])
+ * if not (0 <= x < self.resolution and 0 <= y < self.resolution):
+ */
+ __pyx_t_9 = __pyx_v_i_point;
+ __pyx_t_10 = 0;
+ __pyx_v_x = ((int)(*((double *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_points.data + __pyx_t_9 * __pyx_v_points.strides[0]) ) + __pyx_t_10 * __pyx_v_points.strides[1]) )))); /* plain C cast of the double element to int; element located via byte strides */
+
+ /* "triangle_hash.pyx":65
+ * for i_point in range(n_points):
+ * x = int(points[i_point, 0])
+ * y = int(points[i_point, 1]) # <<<<<<<<<<<<<<
+ * if not (0 <= x < self.resolution and 0 <= y < self.resolution):
+ * continue
+ */
+ __pyx_t_10 = __pyx_v_i_point;
+ __pyx_t_9 = 1;
+ __pyx_v_y = ((int)(*((double *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_points.data + __pyx_t_10 * __pyx_v_points.strides[0]) ) + __pyx_t_9 * __pyx_v_points.strides[1]) ))));
+
+ /* "triangle_hash.pyx":66
+ * x = int(points[i_point, 0])
+ * y = int(points[i_point, 1])
+ * if not (0 <= x < self.resolution and 0 <= y < self.resolution): # <<<<<<<<<<<<<<
+ * continue
+ *
+ * __pyx_t_11 ends up holding the value of the whole `and` expression; the
+ * y-range test is skipped when the x-range test already failed.
+ */
+ __pyx_t_12 = (0 <= __pyx_v_x);
+ if (__pyx_t_12) {
+ __pyx_t_12 = (__pyx_v_x < __pyx_v_self->resolution);
+ }
+ __pyx_t_13 = (__pyx_t_12 != 0);
+ if (__pyx_t_13) {
+ } else {
+ __pyx_t_11 = __pyx_t_13;
+ goto __pyx_L6_bool_binop_done;
+ }
+ __pyx_t_13 = (0 <= __pyx_v_y);
+ if (__pyx_t_13) {
+ __pyx_t_13 = (__pyx_v_y < __pyx_v_self->resolution);
+ }
+ __pyx_t_12 = (__pyx_t_13 != 0);
+ __pyx_t_11 = __pyx_t_12;
+ __pyx_L6_bool_binop_done:;
+ __pyx_t_12 = ((!__pyx_t_11) != 0);
+ if (__pyx_t_12) {
+
+ /* "triangle_hash.pyx":67
+ * y = int(points[i_point, 1])
+ * if not (0 <= x < self.resolution and 0 <= y < self.resolution):
+ * continue # <<<<<<<<<<<<<<
+ *
+ * spatial_idx = self.resolution * x + y
+ */
+ goto __pyx_L3_continue;
+
+ /* "triangle_hash.pyx":66
+ * x = int(points[i_point, 0])
+ * y = int(points[i_point, 1])
+ * if not (0 <= x < self.resolution and 0 <= y < self.resolution): # <<<<<<<<<<<<<<
+ * continue
+ *
+ */
+ }
+
+ /* "triangle_hash.pyx":69
+ * continue
+ *
+ * spatial_idx = self.resolution * x + y # <<<<<<<<<<<<<<
+ * for i_tri in self.spatial_hash[spatial_idx]:
+ * points_indices.push_back(i_point)
+ */
+ __pyx_v_spatial_idx = ((__pyx_v_self->resolution * __pyx_v_x) + __pyx_v_y);
+
+ /* "triangle_hash.pyx":70
+ *
+ * spatial_idx = self.resolution * x + y
+ * for i_tri in self.spatial_hash[spatial_idx]: # <<<<<<<<<<<<<<
+ * points_indices.push_back(i_point)
+ * tri_indices.push_back(i_tri)
+ */
+ __pyx_t_15 = &(__pyx_v_self->spatial_hash[__pyx_v_spatial_idx]);
+ __pyx_t_14 = __pyx_t_15->begin();
+ for (;;) {
+ if (!(__pyx_t_14 != __pyx_t_15->end())) break;
+ __pyx_t_16 = *__pyx_t_14;
+ ++__pyx_t_14;
+ __pyx_v_i_tri = __pyx_t_16;
+
+ /* "triangle_hash.pyx":71
+ * spatial_idx = self.resolution * x + y
+ * for i_tri in self.spatial_hash[spatial_idx]:
+ * points_indices.push_back(i_point) # <<<<<<<<<<<<<<
+ * tri_indices.push_back(i_tri)
+ *
+ * Any C++ exception from push_back is converted to a Python error.
+ */
+ try {
+ __pyx_v_points_indices.push_back(__pyx_v_i_point);
+ } catch(...) {
+ __Pyx_CppExn2PyErr();
+ __PYX_ERR(0, 71, __pyx_L1_error)
+ }
+
+ /* "triangle_hash.pyx":72
+ * for i_tri in self.spatial_hash[spatial_idx]:
+ * points_indices.push_back(i_point)
+ * tri_indices.push_back(i_tri) # <<<<<<<<<<<<<<
+ *
+ * points_indices_np = np.zeros(points_indices.size(), dtype=np.int32)
+ */
+ try {
+ __pyx_v_tri_indices.push_back(__pyx_v_i_tri);
+ } catch(...) {
+ __Pyx_CppExn2PyErr();
+ __PYX_ERR(0, 72, __pyx_L1_error)
+ }
+
+ /* "triangle_hash.pyx":70
+ *
+ * spatial_idx = self.resolution * x + y
+ * for i_tri in self.spatial_hash[spatial_idx]: # <<<<<<<<<<<<<<
+ * points_indices.push_back(i_point)
+ * tri_indices.push_back(i_tri)
+ */
+ }
+ __pyx_L3_continue:;
+ }
+
+ /* "triangle_hash.pyx":74
+ * tri_indices.push_back(i_tri)
+ *
+ * points_indices_np = np.zeros(points_indices.size(), dtype=np.int32) # <<<<<<<<<<<<<<
+ * tri_indices_np = np.zeros(tri_indices.size(), dtype=np.int32)
+ *
+ * Built as a call of np.zeros with a 1-tuple of positional args and a
+ * {dtype: np.int32} keyword dict.
+ */
+ __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_np); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 74, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_zeros); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 74, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_2);
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __pyx_t_1 = __Pyx_PyInt_FromSize_t(__pyx_v_points_indices.size()); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 74, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_t_4 = PyTuple_New(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 74, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_4);
+ __Pyx_GIVEREF(__pyx_t_1);
+ PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_1);
+ __pyx_t_1 = 0;
+ __pyx_t_1 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 74, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_np); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 74, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_3, __pyx_n_s_int32); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 74, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_5);
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ if (PyDict_SetItem(__pyx_t_1, __pyx_n_s_dtype, __pyx_t_5) < 0) __PYX_ERR(0, 74, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
+ __pyx_t_5 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_4, __pyx_t_1); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 74, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_5);
+ __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+ __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __pyx_v_points_indices_np = __pyx_t_5;
+ __pyx_t_5 = 0;
+
+ /* "triangle_hash.pyx":75
+ *
+ * points_indices_np = np.zeros(points_indices.size(), dtype=np.int32)
+ * tri_indices_np = np.zeros(tri_indices.size(), dtype=np.int32) # <<<<<<<<<<<<<<
+ *
+ * cdef int[:] points_indices_view = points_indices_np
+ */
+ __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_np); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 75, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_5);
+ __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_zeros); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 75, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
+ __pyx_t_5 = __Pyx_PyInt_FromSize_t(__pyx_v_tri_indices.size()); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 75, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_5);
+ __pyx_t_4 = PyTuple_New(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 75, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_4);
+ __Pyx_GIVEREF(__pyx_t_5);
+ PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_5);
+ __pyx_t_5 = 0;
+ __pyx_t_5 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 75, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_5);
+ __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_n_s_np); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 75, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_2);
+ __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_int32); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 75, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+ if (PyDict_SetItem(__pyx_t_5, __pyx_n_s_dtype, __pyx_t_3) < 0) __PYX_ERR(0, 75, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_t_3 = __Pyx_PyObject_Call(__pyx_t_1, __pyx_t_4, __pyx_t_5); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 75, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
+ __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
+ __pyx_v_tri_indices_np = __pyx_t_3;
+ __pyx_t_3 = 0;
+
+ /* "triangle_hash.pyx":77
+ * tri_indices_np = np.zeros(tri_indices.size(), dtype=np.int32)
+ *
+ * cdef int[:] points_indices_view = points_indices_np # <<<<<<<<<<<<<<
+ * cdef int[:] tri_indices_view = tri_indices_np
+ *
+ * The temporary slice __pyx_t_17 is moved into the named view and then
+ * cleared, so only the named view is released at function exit.
+ */
+ __pyx_t_17 = __Pyx_PyObject_to_MemoryviewSlice_ds_int(__pyx_v_points_indices_np, PyBUF_WRITABLE); if (unlikely(!__pyx_t_17.memview)) __PYX_ERR(0, 77, __pyx_L1_error)
+ __pyx_v_points_indices_view = __pyx_t_17;
+ __pyx_t_17.memview = NULL;
+ __pyx_t_17.data = NULL;
+
+ /* "triangle_hash.pyx":78
+ *
+ * cdef int[:] points_indices_view = points_indices_np
+ * cdef int[:] tri_indices_view = tri_indices_np # <<<<<<<<<<<<<<
+ *
+ * for k in range(points_indices.size()):
+ */
+ __pyx_t_17 = __Pyx_PyObject_to_MemoryviewSlice_ds_int(__pyx_v_tri_indices_np, PyBUF_WRITABLE); if (unlikely(!__pyx_t_17.memview)) __PYX_ERR(0, 78, __pyx_L1_error)
+ __pyx_v_tri_indices_view = __pyx_t_17;
+ __pyx_t_17.memview = NULL;
+ __pyx_t_17.data = NULL;
+
+ /* "triangle_hash.pyx":80
+ * cdef int[:] tri_indices_view = tri_indices_np
+ *
+ * for k in range(points_indices.size()): # <<<<<<<<<<<<<<
+ * points_indices_view[k] = points_indices[k]
+ *
+ * NOTE(review): the loop counter __pyx_t_6 is declared int while the bound
+ * __pyx_t_19 is an unsigned size_type, so the comparison mixes signedness.
+ */
+ __pyx_t_18 = __pyx_v_points_indices.size();
+ __pyx_t_19 = __pyx_t_18;
+ for (__pyx_t_6 = 0; __pyx_t_6 < __pyx_t_19; __pyx_t_6+=1) {
+ __pyx_v_k = __pyx_t_6;
+
+ /* "triangle_hash.pyx":81
+ *
+ * for k in range(points_indices.size()):
+ * points_indices_view[k] = points_indices[k] # <<<<<<<<<<<<<<
+ *
+ * for k in range(tri_indices.size()):
+ */
+ __pyx_t_9 = __pyx_v_k;
+ *((int *) ( /* dim=0 */ (__pyx_v_points_indices_view.data + __pyx_t_9 * __pyx_v_points_indices_view.strides[0]) )) = (__pyx_v_points_indices[__pyx_v_k]);
+ }
+
+ /* "triangle_hash.pyx":83
+ * points_indices_view[k] = points_indices[k]
+ *
+ * for k in range(tri_indices.size()): # <<<<<<<<<<<<<<
+ * tri_indices_view[k] = tri_indices[k]
+ *
+ */
+ __pyx_t_18 = __pyx_v_tri_indices.size();
+ __pyx_t_19 = __pyx_t_18;
+ for (__pyx_t_6 = 0; __pyx_t_6 < __pyx_t_19; __pyx_t_6+=1) {
+ __pyx_v_k = __pyx_t_6;
+
+ /* "triangle_hash.pyx":84
+ *
+ * for k in range(tri_indices.size()):
+ * tri_indices_view[k] = tri_indices[k] # <<<<<<<<<<<<<<
+ *
+ * return points_indices_np, tri_indices_np
+ */
+ __pyx_t_9 = __pyx_v_k;
+ *((int *) ( /* dim=0 */ (__pyx_v_tri_indices_view.data + __pyx_t_9 * __pyx_v_tri_indices_view.strides[0]) )) = (__pyx_v_tri_indices[__pyx_v_k]);
+ }
+
+ /* "triangle_hash.pyx":86
+ * tri_indices_view[k] = tri_indices[k]
+ *
+ * return points_indices_np, tri_indices_np # <<<<<<<<<<<<<<
+ *
+ * Each array is INCREF'd before being stored in the tuple, because the
+ * local variables are XDECREF'd again at __pyx_L0.
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 86, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_INCREF(__pyx_v_points_indices_np);
+ __Pyx_GIVEREF(__pyx_v_points_indices_np);
+ PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_points_indices_np);
+ __Pyx_INCREF(__pyx_v_tri_indices_np);
+ __Pyx_GIVEREF(__pyx_v_tri_indices_np);
+ PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_v_tri_indices_np);
+ __pyx_r = __pyx_t_3;
+ __pyx_t_3 = 0;
+ goto __pyx_L0;
+
+ /* "triangle_hash.pyx":51
+ * @cython.boundscheck(False) # Deactivate bounds checking
+ * @cython.wraparound(False) # Deactivate negative indexing.
+ * cpdef query(self, double[:, :] points): # <<<<<<<<<<<<<<
+ * assert(points.shape[1] == 2)
+ * cdef int n_points = points.shape[0]
+ */
+
+ /* function exit code: the error label drops the temporaries, adds a traceback entry and sets the result to 0; both paths then release the local arrays and views at __pyx_L0 */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_XDECREF(__pyx_t_2);
+ __Pyx_XDECREF(__pyx_t_3);
+ __Pyx_XDECREF(__pyx_t_4);
+ __Pyx_XDECREF(__pyx_t_5);
+ __PYX_XDEC_MEMVIEW(&__pyx_t_17, 1);
+ __Pyx_AddTraceback("triangle_hash.TriangleHash.query", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = 0;
+ __pyx_L0:;
+ __Pyx_XDECREF(__pyx_v_points_indices_np);
+ __Pyx_XDECREF(__pyx_v_tri_indices_np);
+ __PYX_XDEC_MEMVIEW(&__pyx_v_points_indices_view, 1);
+ __PYX_XDEC_MEMVIEW(&__pyx_v_tri_indices_view, 1);
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* Python wrapper for TriangleHash.query: converts the `points` argument into the memoryview slice used for `double[:, :] points` (requested with PyBUF_WRITABLE) and hands it to the __pyx_pf_ implementation; if the conversion fails it adds a traceback entry and returns NULL. */
+static PyObject *__pyx_pw_13triangle_hash_12TriangleHash_3query(PyObject *__pyx_v_self, PyObject *__pyx_arg_points); /*proto*/
+static PyObject *__pyx_pw_13triangle_hash_12TriangleHash_3query(PyObject *__pyx_v_self, PyObject *__pyx_arg_points) {
+ __Pyx_memviewslice __pyx_v_points = { 0, 0, { 0 }, { 0 }, { 0 } };
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ PyObject *__pyx_r = 0;
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("query (wrapper)", 0);
+ assert(__pyx_arg_points); {
+ __pyx_v_points = __Pyx_PyObject_to_MemoryviewSlice_dsds_double(__pyx_arg_points, PyBUF_WRITABLE); if (unlikely(!__pyx_v_points.memview)) __PYX_ERR(0, 51, __pyx_L3_error)
+ }
+ goto __pyx_L4_argument_unpacking_done;
+ __pyx_L3_error:;
+ __Pyx_AddTraceback("triangle_hash.TriangleHash.query", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __Pyx_RefNannyFinishContext();
+ return NULL;
+ __pyx_L4_argument_unpacking_done:;
+ __pyx_r = __pyx_pf_13triangle_hash_12TriangleHash_2query(((struct __pyx_obj_13triangle_hash_TriangleHash *)__pyx_v_self), __pyx_v_points);
+
+ /* function exit code: the result of the implementation is returned unchanged */
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+static PyObject *__pyx_pf_13triangle_hash_12TriangleHash_2query(struct __pyx_obj_13triangle_hash_TriangleHash *__pyx_v_self, __Pyx_memviewslice __pyx_v_points) { /* Invoked by the Python wrapper: calls the cpdef body with __pyx_skip_dispatch = 1 and releases the `points` slice before returning. */
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("query", 0);
+ __Pyx_XDECREF(__pyx_r);
+ if (unlikely(!__pyx_v_points.memview)) { __Pyx_RaiseUnboundLocalError("points"); __PYX_ERR(0, 51, __pyx_L1_error) }
+ __pyx_t_1 = __pyx_f_13triangle_hash_12TriangleHash_query(__pyx_v_self, __pyx_v_points, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 51, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_r = __pyx_t_1;
+ __pyx_t_1 = 0;
+ goto __pyx_L0;
+
+ /* function exit code: on error a traceback entry is added and NULL is returned; the `points` slice is released on both paths */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_AddTraceback("triangle_hash.TriangleHash.query", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = NULL;
+ __pyx_L0:;
+ __PYX_XDEC_MEMVIEW(&__pyx_v_points, 1);
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "(tree fragment)":1
+ * def __reduce_cython__(self): # <<<<<<<<<<<<<<
+ * raise TypeError("no default __reduce__ due to non-trivial __cinit__")
+ * def __setstate_cython__(self, __pyx_state):
+ */
+
+/* Python wrapper for __reduce_cython__: casts self to the TriangleHash struct and delegates to the __pyx_pf_ implementation; the second argument is unused. */
+static PyObject *__pyx_pw_13triangle_hash_12TriangleHash_5__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused); /*proto*/
+static PyObject *__pyx_pw_13triangle_hash_12TriangleHash_5__reduce_cython__(PyObject *__pyx_v_self, CYTHON_UNUSED PyObject *unused) {
+ PyObject *__pyx_r = 0;
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0);
+ __pyx_r = __pyx_pf_13triangle_hash_12TriangleHash_4__reduce_cython__(((struct __pyx_obj_13triangle_hash_TriangleHash *)__pyx_v_self));
+
+ /* function exit code: the implementation's result is returned unchanged */
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+static PyObject *__pyx_pf_13triangle_hash_12TriangleHash_4__reduce_cython__(CYTHON_UNUSED struct __pyx_obj_13triangle_hash_TriangleHash *__pyx_v_self) { /* Always fails: raises TypeError and returns NULL; there is no success path. */
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("__reduce_cython__", 0);
+
+ /* "(tree fragment)":2
+ * def __reduce_cython__(self):
+ * raise TypeError("no default __reduce__ due to non-trivial __cinit__") # <<<<<<<<<<<<<<
+ * def __setstate_cython__(self, __pyx_state):
+ * raise TypeError("no default __reduce__ due to non-trivial __cinit__")
+ *
+ * The exception instance is built by calling __pyx_builtin_TypeError with
+ * the prebuilt argument tuple __pyx_tuple_, then raised; control always
+ * continues at the error label.
+ */
+ __pyx_t_1 = __Pyx_PyObject_Call(__pyx_builtin_TypeError, __pyx_tuple_, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 2, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __Pyx_Raise(__pyx_t_1, 0, 0, 0);
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __PYX_ERR(1, 2, __pyx_L1_error)
+
+ /* "(tree fragment)":1
+ * def __reduce_cython__(self): # <<<<<<<<<<<<<<
+ * raise TypeError("no default __reduce__ due to non-trivial __cinit__")
+ * def __setstate_cython__(self, __pyx_state):
+ */
+
+ /* function exit code: adds a traceback entry and returns NULL */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_AddTraceback("triangle_hash.TriangleHash.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = NULL;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "(tree fragment)":3
+ * def __reduce_cython__(self):
+ * raise TypeError("no default __reduce__ due to non-trivial __cinit__")
+ * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<<
+ * raise TypeError("no default __reduce__ due to non-trivial __cinit__")
+ */
+
+/* Python wrapper for __setstate_cython__: casts self to the TriangleHash struct and passes it, together with the state object, to the __pyx_pf_ implementation. */
+static PyObject *__pyx_pw_13triangle_hash_12TriangleHash_7__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state); /*proto*/
+static PyObject *__pyx_pw_13triangle_hash_12TriangleHash_7__setstate_cython__(PyObject *__pyx_v_self, PyObject *__pyx_v___pyx_state) {
+ PyObject *__pyx_r = 0;
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0);
+ __pyx_r = __pyx_pf_13triangle_hash_12TriangleHash_6__setstate_cython__(((struct __pyx_obj_13triangle_hash_TriangleHash *)__pyx_v_self), ((PyObject *)__pyx_v___pyx_state));
+
+ /* function exit code: the implementation's result is returned unchanged */
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+static PyObject *__pyx_pf_13triangle_hash_12TriangleHash_6__setstate_cython__(CYTHON_UNUSED struct __pyx_obj_13triangle_hash_TriangleHash *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { /* Always fails: raises TypeError and returns NULL; both arguments are unused. */
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("__setstate_cython__", 0);
+
+ /* "(tree fragment)":4
+ * raise TypeError("no default __reduce__ due to non-trivial __cinit__")
+ * def __setstate_cython__(self, __pyx_state):
+ * raise TypeError("no default __reduce__ due to non-trivial __cinit__") # <<<<<<<<<<<<<<
+ *
+ * The exception instance is built by calling __pyx_builtin_TypeError with
+ * the prebuilt argument tuple __pyx_tuple__2, then raised; control always
+ * continues at the error label.
+ */
+ __pyx_t_1 = __Pyx_PyObject_Call(__pyx_builtin_TypeError, __pyx_tuple__2, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 4, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __Pyx_Raise(__pyx_t_1, 0, 0, 0);
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __PYX_ERR(1, 4, __pyx_L1_error)
+
+ /* "(tree fragment)":3
+ * def __reduce_cython__(self):
+ * raise TypeError("no default __reduce__ due to non-trivial __cinit__")
+ * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<<
+ * raise TypeError("no default __reduce__ due to non-trivial __cinit__")
+ */
+
+ /* function exit code: adds a traceback entry and returns NULL */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_AddTraceback("triangle_hash.TriangleHash.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = NULL;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":734
+ * ctypedef npy_cdouble complex_t
+ *
+ * cdef inline object PyArray_MultiIterNew1(a): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(1, a)
+ *
+ * Generated inline helper: forwards `a` (cast to void *) to
+ * PyArray_MultiIterNew(1, ...) and returns its result; when that result is
+ * NULL it adds a traceback entry and returns 0.
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew1(PyObject *__pyx_v_a) {
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("PyArray_MultiIterNew1", 0);
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":735
+ *
+ * cdef inline object PyArray_MultiIterNew1(a):
+ * return PyArray_MultiIterNew(1, a) # <<<<<<<<<<<<<<
+ *
+ * cdef inline object PyArray_MultiIterNew2(a, b):
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __pyx_t_1 = PyArray_MultiIterNew(1, ((void *)__pyx_v_a)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 735, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_r = __pyx_t_1;
+ __pyx_t_1 = 0;
+ goto __pyx_L0;
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":734
+ * ctypedef npy_cdouble complex_t
+ *
+ * cdef inline object PyArray_MultiIterNew1(a): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(1, a)
+ *
+ */
+
+ /* function exit code */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_AddTraceback("numpy.PyArray_MultiIterNew1", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = 0;
+ __pyx_L0:;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":737
+ * return PyArray_MultiIterNew(1, a)
+ *
+ * cdef inline object PyArray_MultiIterNew2(a, b): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(2, a, b)
+ *
+ * Generated inline helper: forwards `a` and `b` (each cast to void *) to
+ * PyArray_MultiIterNew(2, ...) and returns its result; when that result is
+ * NULL it adds a traceback entry and returns 0.
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew2(PyObject *__pyx_v_a, PyObject *__pyx_v_b) {
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("PyArray_MultiIterNew2", 0);
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":738
+ *
+ * cdef inline object PyArray_MultiIterNew2(a, b):
+ * return PyArray_MultiIterNew(2, a, b) # <<<<<<<<<<<<<<
+ *
+ * cdef inline object PyArray_MultiIterNew3(a, b, c):
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __pyx_t_1 = PyArray_MultiIterNew(2, ((void *)__pyx_v_a), ((void *)__pyx_v_b)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 738, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_r = __pyx_t_1;
+ __pyx_t_1 = 0;
+ goto __pyx_L0;
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":737
+ * return PyArray_MultiIterNew(1, a)
+ *
+ * cdef inline object PyArray_MultiIterNew2(a, b): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(2, a, b)
+ *
+ */
+
+ /* function exit code */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_AddTraceback("numpy.PyArray_MultiIterNew2", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = 0;
+ __pyx_L0:;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":740
+ * return PyArray_MultiIterNew(2, a, b)
+ *
+ * cdef inline object PyArray_MultiIterNew3(a, b, c): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(3, a, b, c)
+ *
+ * Generated inline helper: forwards `a`, `b` and `c` (each cast to void *) to
+ * PyArray_MultiIterNew(3, ...) and returns its result; when that result is
+ * NULL it adds a traceback entry and returns 0.
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew3(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c) {
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("PyArray_MultiIterNew3", 0);
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":741
+ *
+ * cdef inline object PyArray_MultiIterNew3(a, b, c):
+ * return PyArray_MultiIterNew(3, a, b, c) # <<<<<<<<<<<<<<
+ *
+ * cdef inline object PyArray_MultiIterNew4(a, b, c, d):
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __pyx_t_1 = PyArray_MultiIterNew(3, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 741, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_r = __pyx_t_1;
+ __pyx_t_1 = 0;
+ goto __pyx_L0;
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":740
+ * return PyArray_MultiIterNew(2, a, b)
+ *
+ * cdef inline object PyArray_MultiIterNew3(a, b, c): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(3, a, b, c)
+ *
+ */
+
+ /* function exit code */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_AddTraceback("numpy.PyArray_MultiIterNew3", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = 0;
+ __pyx_L0:;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":743
+ * return PyArray_MultiIterNew(3, a, b, c)
+ *
+ * cdef inline object PyArray_MultiIterNew4(a, b, c, d): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(4, a, b, c, d)
+ *
+ * Generated inline helper: forwards `a` through `d` (each cast to void *) to
+ * PyArray_MultiIterNew(4, ...) and returns its result; when that result is
+ * NULL it adds a traceback entry and returns 0.
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew4(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c, PyObject *__pyx_v_d) {
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("PyArray_MultiIterNew4", 0);
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":744
+ *
+ * cdef inline object PyArray_MultiIterNew4(a, b, c, d):
+ * return PyArray_MultiIterNew(4, a, b, c, d) # <<<<<<<<<<<<<<
+ *
+ * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e):
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __pyx_t_1 = PyArray_MultiIterNew(4, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c), ((void *)__pyx_v_d)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 744, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_r = __pyx_t_1;
+ __pyx_t_1 = 0;
+ goto __pyx_L0;
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":743
+ * return PyArray_MultiIterNew(3, a, b, c)
+ *
+ * cdef inline object PyArray_MultiIterNew4(a, b, c, d): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(4, a, b, c, d)
+ *
+ */
+
+ /* function exit code */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_AddTraceback("numpy.PyArray_MultiIterNew4", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = 0;
+ __pyx_L0:;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":746
+ * return PyArray_MultiIterNew(4, a, b, c, d)
+ *
+ * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(5, a, b, c, d, e)
+ *
+ * Generated inline helper: forwards `a` through `e` (each cast to void *) to
+ * PyArray_MultiIterNew(5, ...) and returns its result; when that result is
+ * NULL it adds a traceback entry and returns 0.
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew5(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c, PyObject *__pyx_v_d, PyObject *__pyx_v_e) {
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("PyArray_MultiIterNew5", 0);
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":747
+ *
+ * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e):
+ * return PyArray_MultiIterNew(5, a, b, c, d, e) # <<<<<<<<<<<<<<
+ *
+ * cdef inline tuple PyDataType_SHAPE(dtype d):
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __pyx_t_1 = PyArray_MultiIterNew(5, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c), ((void *)__pyx_v_d), ((void *)__pyx_v_e)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 747, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_r = __pyx_t_1;
+ __pyx_t_1 = 0;
+ goto __pyx_L0;
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":746
+ * return PyArray_MultiIterNew(4, a, b, c, d)
+ *
+ * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(5, a, b, c, d, e)
+ *
+ */
+
+ /* function exit code */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_AddTraceback("numpy.PyArray_MultiIterNew5", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = 0;
+ __pyx_L0:;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":749
+ * return PyArray_MultiIterNew(5, a, b, c, d, e)
+ *
+ * cdef inline tuple PyDataType_SHAPE(dtype d): # <<<<<<<<<<<<<<
+ * if PyDataType_HASSUBARRAY(d):
+ * return d.subarray.shape
+ *
+ * Generated C form of the .pxd helper quoted above.
+ *
+ * Purpose: return the shape object stored in the descriptor's subarray when
+ * PyDataType_HASSUBARRAY reports one, otherwise the module's empty tuple.
+ *
+ * Parameters: __pyx_v_d - the descriptor to inspect; it is dereferenced
+ * (d->subarray->shape) only on the branch where the HASSUBARRAY test is true.
+ * Returns: the chosen object, with its reference count incremented before it
+ * is returned. There is no error path in this function.
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyDataType_SHAPE(PyArray_Descr *__pyx_v_d) {
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ int __pyx_t_1;
+ __Pyx_RefNannySetupContext("PyDataType_SHAPE", 0);
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":750
+ *
+ * cdef inline tuple PyDataType_SHAPE(dtype d):
+ * if PyDataType_HASSUBARRAY(d): # <<<<<<<<<<<<<<
+ * return d.subarray.shape
+ * else:
+ *
+ * The macro result is compared against 0, so __pyx_t_1 holds exactly 0 or 1.
+ */
+ __pyx_t_1 = (PyDataType_HASSUBARRAY(__pyx_v_d) != 0);
+ if (__pyx_t_1) {
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":751
+ * cdef inline tuple PyDataType_SHAPE(dtype d):
+ * if PyDataType_HASSUBARRAY(d):
+ * return d.subarray.shape # <<<<<<<<<<<<<<
+ * else:
+ * return ()
+ *
+ * __pyx_r is still NULL here, so the XDECREF releases nothing. The shape
+ * object gets an extra reference before being stored as the result, and
+ * control jumps to the shared exit label.
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __Pyx_INCREF(((PyObject*)__pyx_v_d->subarray->shape));
+ __pyx_r = ((PyObject*)__pyx_v_d->subarray->shape);
+ goto __pyx_L0;
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":750
+ *
+ * cdef inline tuple PyDataType_SHAPE(dtype d):
+ * if PyDataType_HASSUBARRAY(d): # <<<<<<<<<<<<<<
+ * return d.subarray.shape
+ * else:
+ */
+ }
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":753
+ * return d.subarray.shape
+ * else:
+ * return () # <<<<<<<<<<<<<<
+ *
+ *
+ * No-subarray branch: the result is __pyx_empty_tuple (defined outside this
+ * chunk), again with an extra reference taken before it is returned.
+ */
+ /*else*/ {
+ __Pyx_XDECREF(__pyx_r);
+ __Pyx_INCREF(__pyx_empty_tuple);
+ __pyx_r = __pyx_empty_tuple;
+ goto __pyx_L0;
+ }
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":749
+ * return PyArray_MultiIterNew(5, a, b, c, d, e)
+ *
+ * cdef inline tuple PyDataType_SHAPE(dtype d): # <<<<<<<<<<<<<<
+ * if PyDataType_HASSUBARRAY(d):
+ * return d.subarray.shape
+ */
+
+ /* function exit code
+ *
+ * __pyx_L0 is the single exit: both branches above jump here with __pyx_r
+ * already set; the RefNanny context is finished and the result returned.
+ */
+ __pyx_L0:;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":928
+ * int _import_umath() except -1
+ *
+ * cdef inline void set_array_base(ndarray arr, object base): # <<<<<<<<<<<<<<
+ * Py_INCREF(base) # important to do this before stealing the reference below!
+ * PyArray_SetBaseObject(arr, base)
+ *
+ * Generated C form of the .pxd helper quoted above.
+ *
+ * Purpose: increment the reference count of base, then register it on arr
+ * through PyArray_SetBaseObject.
+ *
+ * Parameters: __pyx_v_arr - the array whose base is set;
+ * __pyx_v_base - the object to install as the base.
+ * Returns: nothing. The value returned by PyArray_SetBaseObject is discarded,
+ * so this function gives the caller no indication of whether the call worked.
+ */
+
+static CYTHON_INLINE void __pyx_f_5numpy_set_array_base(PyArrayObject *__pyx_v_arr, PyObject *__pyx_v_base) {
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("set_array_base", 0);
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":929
+ *
+ * cdef inline void set_array_base(ndarray arr, object base):
+ * Py_INCREF(base) # important to do this before stealing the reference below! # <<<<<<<<<<<<<<
+ * PyArray_SetBaseObject(arr, base)
+ *
+ * Per the quoted source comment, the call that follows steals a reference
+ * to base, which is why the count is raised first.
+ */
+ Py_INCREF(__pyx_v_base);
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":930
+ * cdef inline void set_array_base(ndarray arr, object base):
+ * Py_INCREF(base) # important to do this before stealing the reference below!
+ * PyArray_SetBaseObject(arr, base) # <<<<<<<<<<<<<<
+ *
+ * cdef inline object get_array_base(ndarray arr):
+ *
+ * The result is explicitly cast to void and dropped.
+ * NOTE(review): if PyArray_SetBaseObject can fail, that failure is silently
+ * ignored here - confirm against the NumPy C API documentation.
+ */
+ (void)(PyArray_SetBaseObject(__pyx_v_arr, __pyx_v_base));
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":928
+ * int _import_umath() except -1
+ *
+ * cdef inline void set_array_base(ndarray arr, object base): # <<<<<<<<<<<<<<
+ * Py_INCREF(base) # important to do this before stealing the reference below!
+ * PyArray_SetBaseObject(arr, base)
+ */
+
+ /* function exit code */
+ __Pyx_RefNannyFinishContext();
+}
+
+/* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":932
+ * PyArray_SetBaseObject(arr, base)
+ *
+ * cdef inline object get_array_base(ndarray arr): # <<<<<<<<<<<<<<
+ * base = PyArray_BASE(arr)
+ * if base is NULL:
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_get_array_base(PyArrayObject *__pyx_v_arr) {
+ PyObject *__pyx_v_base;
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ int __pyx_t_1;
+ __Pyx_RefNannySetupContext("get_array_base", 0);
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":933
+ *
+ * cdef inline object get_array_base(ndarray arr):
+ * base = PyArray_BASE(arr) # <<<<<<<<<<<<<<
+ * if base is NULL:
+ * return None
+ */
+ __pyx_v_base = PyArray_BASE(__pyx_v_arr);
+
+ /* "../../../../miniconda/envs/econ/lib/python3.8/site-packages/numpy/__init__.pxd":934
+ * cdef inline object get_array_base(ndarray arr):
+ * base = PyArray_BASE(arr)
+ * if base is NULL: # <<<<<<<<<<<<<<
+ * return None
+ * return