
Comments (5)

ashawkey commented on July 20, 2024

@Spark001 Hi, the code is not cleaned up, so you will need to adapt it to your case:

import numpy as np
import trimesh
from sklearn.neighbors import NearestNeighbors
from scipy.spatial.transform import Rotation as Rot
import os, json
import torch
import cubvh
import matplotlib.pyplot as plt

from packaging import version as pver

import logging
logger = logging.getLogger("trimesh")
logger.setLevel(logging.ERROR)

def custom_meshgrid(*args):
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        # torch >= 1.10 warns unless indexing is explicit; 'ij' matches the old default
        return torch.meshgrid(*args, indexing='ij')

W = 800
H = 800

# ref: https://gist.github.com/sergeyprokudin/c4bf4059230da8db8256e36524993367
def chamfer_distance(x, y, metric='l2', direction='bi'):
    ''' Chamfer distance between two point sets x, y: [N, 3] ndarrays. '''

    if direction=='y_to_x':
        x_nn = NearestNeighbors(n_neighbors=1, leaf_size=1, algorithm='kd_tree', metric=metric).fit(x)
        min_y_to_x = x_nn.kneighbors(y)[0]
        chamfer_dist = np.mean(min_y_to_x)
    elif direction=='x_to_y':
        y_nn = NearestNeighbors(n_neighbors=1, leaf_size=1, algorithm='kd_tree', metric=metric).fit(y)
        min_x_to_y = y_nn.kneighbors(x)[0]
        chamfer_dist = np.mean(min_x_to_y)
    elif direction=='bi':
        x_nn = NearestNeighbors(n_neighbors=1, leaf_size=1, algorithm='kd_tree', metric=metric).fit(x)
        min_y_to_x = x_nn.kneighbors(y)[0]
        y_nn = NearestNeighbors(n_neighbors=1, leaf_size=1, algorithm='kd_tree', metric=metric).fit(y)
        min_x_to_y = y_nn.kneighbors(x)[0]
        chamfer_dist = (np.mean(min_y_to_x) + np.mean(min_x_to_y)) / 2 # averaged (not summed) to keep the scale of a one-directional distance
    else:
        raise ValueError("Invalid direction type. Supported types: 'y_to_x', 'x_to_y', 'bi'")
        
    return chamfer_dist
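
# Sanity check (hypothetical usage): identical point sets give a distance of zero,
# and the bidirectional variant is symmetric in its arguments:
#   pts = np.random.rand(1000, 3)
#   assert chamfer_distance(pts, pts) == 0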


@torch.cuda.amp.autocast(enabled=False)
def get_rays(poses, intrinsics, H, W, N=-1, patch_size=1, coords=None):
    ''' get rays
    Args:
        poses: [N/1, 4, 4], cam2world
        intrinsics: [N/1, 4] tensor or [4] ndarray
        H, W, N: int
    Returns:
        rays_o, rays_d: [N, 3]
        i, j: [N]
    '''

    device = poses.device
    
    if isinstance(intrinsics, np.ndarray):
        fx, fy, cx, cy = intrinsics
    else:
        fx, fy, cx, cy = intrinsics[:, 0], intrinsics[:, 1], intrinsics[:, 2], intrinsics[:, 3]

    i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device)) # float
    i = i.t().contiguous().view(-1) + 0.5
    j = j.t().contiguous().view(-1) + 0.5

    results = {}

    if N > 0:
       
        if coords is not None:
            inds = coords[:, 0] * W + coords[:, 1]

        elif patch_size > 1:

            # randomly sample the top-left corners of patches.
            num_patch = N // (patch_size ** 2)
            inds_x = torch.randint(0, H - patch_size, size=[num_patch], device=device)
            inds_y = torch.randint(0, W - patch_size, size=[num_patch], device=device)
            inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2]

            # create meshgrid for each patch
            pi, pj = custom_meshgrid(torch.arange(patch_size, device=device), torch.arange(patch_size, device=device))
            offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1) # [p^2, 2]

            inds = inds.unsqueeze(1) + offsets.unsqueeze(0) # [np, p^2, 2]
            inds = inds.view(-1, 2) # [N, 2]
            inds = inds[:, 0] * W + inds[:, 1] # [N], flatten


        else: # random sampling
            inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate

        i = torch.gather(i, -1, inds)
        j = torch.gather(j, -1, inds)

        results['i'] = i.long()
        results['j'] = j.long()

    else:
        inds = torch.arange(H*W, device=device)

    zs = -torch.ones_like(i) # z is flipped
    xs = (i - cx) / fx
    ys = -(j - cy) / fy # y is flipped
    directions = torch.stack((xs, ys, zs), dim=-1) # [N, 3]
    # do not normalize to get actual depth, ref: https://github.com/dunbar12138/DSNeRF/issues/29
    # directions = directions / torch.norm(directions, dim=-1, keepdim=True) 
    rays_d = (directions.unsqueeze(1) @ poses[:, :3, :3].transpose(-1, -2)).squeeze(1) # [N, 1, 3] @ [N, 3, 3] --> [N, 1, 3]

    rays_o = poses[:, :3, 3].expand_as(rays_d) # [N, 3]

    results['rays_o'] = rays_o
    results['rays_d'] = rays_d

    # visualize_rays(rays_o[0].detach().cpu().numpy(), rays_d[0].detach().cpu().numpy())

    return results


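# Sample surface points by casting rays from the test views, so only surfaces
# visible from the cameras contribute to the chamfer distance (hidden or
# internal geometry would otherwise skew the metric).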
def sample_surface(poses, intrinsics, mesh, N):

    # normalize
    vmin, vmax = mesh.bounds
    center = (vmin + vmax) / 2
    scale = 1 / (vmax - vmin)
    mesh.vertices = (mesh.vertices - center) * scale

    RT = cubvh.cuBVH(mesh.vertices, mesh.faces) # GPU BVH for ray-mesh intersection

    # need to cast rays ...
    all_positions = []

    per_frame_n = N // len(poses)

    for pose in poses:
        
        pose = torch.from_numpy(pose).unsqueeze(0).cuda()
        rays = get_rays(pose, intrinsics, H, W, -1)
        rays_o = rays['rays_o'].contiguous().view(-1, 3)
        rays_d = rays['rays_d'].contiguous().view(-1, 3)

        positions, face_id, depth = RT.ray_trace(rays_o, rays_d)

        # depth = depth.detach().cpu().numpy().reshape(H, W, 1)
        # mask = depth >= 10
        # mn = depth[~mask].min()
        # mx = depth[~mask].max()
        # depth = (depth - mn) / (mx - mn + 1e-5)
        # depth[mask] = 0
        # depth = depth.repeat(3, -1)
        # plt.imshow(depth)
        # plt.show()

        mask = face_id >= 0
        positions = positions[mask].detach().cpu().numpy().reshape(-1, 3)

        # guard against frames where fewer than per_frame_n rays hit the mesh
        indices = np.random.choice(len(positions), min(per_frame_n, len(positions)), replace=False)
        positions = positions[indices]

        all_positions.append(positions)

    all_positions = np.concatenate(all_positions, axis=0)

    # revert 
    all_positions = (all_positions / scale) + center

    # scene = trimesh.Scene([mesh, trimesh.PointCloud(all_positions)])
    # scene.show()

    return all_positions


def visualize_poses(poses, size=0.05, bound=1, mesh=None):
    # poses: [B, 4, 4]

    axes = trimesh.creation.axis(axis_length=4)
    box = trimesh.primitives.Box(extents=[2*bound]*3).as_outline()
    box.colors = np.array([[128, 128, 128]] * len(box.entities))
    objects = [axes, box]

    for pose in poses:
        # a camera is visualized with 8 line segments.
        pos = pose[:3, 3]
        a = pos + size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2]
        b = pos - size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2]
        c = pos - size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2]
        d = pos + size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2]

        dir = (a + b + c + d) / 4 - pos
        dir = dir / (np.linalg.norm(dir) + 1e-8)
        o = pos + dir * 3

        segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]])
        segs = trimesh.load_path(segs)
        objects.append(segs)

    if mesh is not None:
        objects.append(mesh)

    scene = trimesh.Scene(objects)
    scene.show()

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('pred', type=str)
    parser.add_argument('gt', type=str)
    parser.add_argument('--N', type=int, default=250000)
    parser.add_argument('--scale', type=float, default=0.8)
    parser.add_argument('--fix_pred_coord', action='store_true')
    parser.add_argument('--vis', action='store_true')
    opt = parser.parse_args()
    
    pred_mesh = trimesh.load(opt.pred, force='mesh', skip_material=True, process=False)
    gt_mesh = trimesh.load(opt.gt, force='mesh', skip_material=True)

    name = os.path.basename(opt.gt).replace('.obj', '')
    

    # fix gt coord (the exported OBJs are y-up; rotate back to the dataset's z-up convention)
    v = gt_mesh.vertices
    R = Rot.from_euler('x', 90, degrees=True)
    v = R.apply(v)
    gt_mesh.vertices = v

    # fix my scale
    if opt.scale != 1:
        v = pred_mesh.vertices # [N, 3]
        v /= opt.scale
        pred_mesh.vertices = v

    if opt.fix_pred_coord: # for nvdiffrec's output
        v = pred_mesh.vertices
        R = Rot.from_euler('x', 90, degrees=True)
        v = R.apply(v)
        pred_mesh.vertices = v

    # scene = trimesh.Scene([pred_mesh, gt_mesh])
    # scene.show()

    # NOTE: modify this path to your own dataset location
    root_path = f'/data/tang/mobile-ngp/data/nerf_synthetic/{name}/transforms_test.json'

    with open(root_path, 'r') as f:
        transform = json.load(f)
    
    frames = np.array(transform["frames"])
    poses = []
    for frame in frames:
        pose = np.array(frame['transform_matrix'], dtype=np.float32) # [4, 4]
        poses.append(pose)
    poses = np.stack(poses, axis=0)

    # visualize_poses(poses, mesh=gt_mesh)
    
    # load intrinsics
    if 'fl_x' in transform or 'fl_y' in transform:
        fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y'])
        fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x'])
    elif 'camera_angle_x' in transform or 'camera_angle_y' in transform:
        # blender convention, assumed to be in radians; already downscaled since we use H/W
        fl_x = W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None
        fl_y = H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None
        if fl_x is None: fl_x = fl_y
        if fl_y is None: fl_y = fl_x
    else:
        raise RuntimeError('Failed to load focal length, please check the transforms.json!')

    cx = (transform['cx']) if 'cx' in transform else (W / 2.0)
    cy = (transform['cy']) if 'cy' in transform else (H / 2.0)
    
    intrinsics = np.array([fl_x, fl_y, cx, cy])

    gt_points = sample_surface(poses, intrinsics, gt_mesh, opt.N)
    pred_points = sample_surface(poses, intrinsics, pred_mesh, opt.N)

    if opt.vis:
        gt_color = np.array([[0, 0, 255]], dtype=np.uint8).repeat(len(gt_points), 0)
        pred_color = np.array([[255, 0, 0]], dtype=np.uint8).repeat(len(pred_points), 0)
        scene = trimesh.Scene([trimesh.PointCloud(pred_points, pred_color), trimesh.PointCloud(gt_points, gt_color)])
        scene.show()

    cd = chamfer_distance(pred_points, gt_points, direction='bi')

    print(f'[CD] {name: <20} {cd:.6f}')

Run it like this:

echo "============================="
# coarse mesh
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_chair/mesh_stage0/mesh_0.ply data/nerf-syn-gt-mesh/chair.obj --scale 0.8
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_drums/mesh_stage0/mesh_0.ply data/nerf-syn-gt-mesh/drums.obj --scale 0.8
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_ficus/mesh_stage0/mesh_0.ply data/nerf-syn-gt-mesh/ficus.obj --scale 0.8
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_hotdog/mesh_stage0/mesh_0.ply data/nerf-syn-gt-mesh/hotdog.obj --scale 0.7
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_lego/mesh_stage0/mesh_0.ply data/nerf-syn-gt-mesh/lego.obj --scale 0.8
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_materials/mesh_stage0/mesh_0.ply data/nerf-syn-gt-mesh/materials.obj --scale 0.8
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_mic/mesh_stage0/mesh_0.ply data/nerf-syn-gt-mesh/mic.obj --scale 0.8
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_ship/mesh_stage0/mesh_0.ply data/nerf-syn-gt-mesh/ship.obj --scale 0.7


echo "============================="
# fine mesh
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_chair/mesh_stage1/mesh_0.obj data/nerf-syn-gt-mesh/chair.obj --scale 0.8
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_drums/mesh_stage1/mesh_0.obj data/nerf-syn-gt-mesh/drums.obj --scale 0.8
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_ficus/mesh_stage1/mesh_0.obj data/nerf-syn-gt-mesh/ficus.obj --scale 0.8
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_hotdog/mesh_stage1/mesh_0.obj data/nerf-syn-gt-mesh/hotdog.obj --scale 0.7
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_lego/mesh_stage1/mesh_0.obj data/nerf-syn-gt-mesh/lego.obj --scale 0.8
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_materials/mesh_stage1/mesh_0.obj data/nerf-syn-gt-mesh/materials.obj --scale 0.8
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_mic/mesh_stage1/mesh_0.obj data/nerf-syn-gt-mesh/mic.obj --scale 0.8
python scripts/eval_chamfer_distance.py ../mobile-ngp-opensource/trial_syn_ship/mesh_stage1/mesh_0.obj data/nerf-syn-gt-mesh/ship.obj --scale 0.7


echo "============================="
# neus
python scripts/eval_chamfer_distance.py data/neus_mesh/chair.obj data/nerf-syn-gt-mesh/chair.obj --scale 1
python scripts/eval_chamfer_distance.py data/neus_mesh/drums.obj data/nerf-syn-gt-mesh/drums.obj --scale 1
python scripts/eval_chamfer_distance.py data/neus_mesh/ficus.obj data/nerf-syn-gt-mesh/ficus.obj --scale 1
python scripts/eval_chamfer_distance.py data/neus_mesh/hotdog.obj data/nerf-syn-gt-mesh/hotdog.obj --scale 1
python scripts/eval_chamfer_distance.py data/neus_mesh/lego.obj data/nerf-syn-gt-mesh/lego.obj --scale 1
python scripts/eval_chamfer_distance.py data/neus_mesh/materials.obj data/nerf-syn-gt-mesh/materials.obj --scale 1
python scripts/eval_chamfer_distance.py data/neus_mesh/mic.obj data/nerf-syn-gt-mesh/mic.obj --scale 1
python scripts/eval_chamfer_distance.py data/neus_mesh/ship.obj data/nerf-syn-gt-mesh/ship.obj --scale 1


echo "============================="
# nvdiffrec
python scripts/eval_chamfer_distance.py ../nvdiffrec_original/out/nerf_chair/mesh/mesh.obj data/nerf-syn-gt-mesh/chair.obj --scale 1 --fix_pred_coord
python scripts/eval_chamfer_distance.py ../nvdiffrec_original/out/nerf_drums/mesh/mesh.obj data/nerf-syn-gt-mesh/drums.obj --scale 1 --fix_pred_coord
python scripts/eval_chamfer_distance.py ../nvdiffrec_original/out/nerf_ficus/mesh/mesh.obj data/nerf-syn-gt-mesh/ficus.obj --scale 1 --fix_pred_coord
python scripts/eval_chamfer_distance.py ../nvdiffrec_original/out/nerf_hotdog/mesh/mesh.obj data/nerf-syn-gt-mesh/hotdog.obj --scale 1 --fix_pred_coord
python scripts/eval_chamfer_distance.py ../nvdiffrec_original/out/nerf_lego/mesh/mesh.obj data/nerf-syn-gt-mesh/lego.obj --scale 1 --fix_pred_coord
python scripts/eval_chamfer_distance.py ../nvdiffrec_original/out/nerf_materials/mesh/mesh.obj data/nerf-syn-gt-mesh/materials.obj --scale 1 --fix_pred_coord
python scripts/eval_chamfer_distance.py ../nvdiffrec_original/out/nerf_mic/mesh/mesh.obj data/nerf-syn-gt-mesh/mic.obj --scale 1 --fix_pred_coord
python scripts/eval_chamfer_distance.py ../nvdiffrec_original/out/nerf_ship/mesh/mesh.obj data/nerf-syn-gt-mesh/ship.obj --scale 1 --fix_pred_coord


WB-3 commented on July 20, 2024

@ashawkey Hello, is nerf-syn-gt-mesh derived from this:
[attached screenshot]


ashawkey commented on July 20, 2024

@WB-3 Yes, we export OBJ files from the Blender projects.
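
For anyone reproducing this, here is a minimal sketch of such an export using Blender's Python API (the file names are hypothetical, and note that bpy.ops.export_scene.obj was replaced by bpy.ops.wm.obj_export in Blender 4.0):

# Hypothetical sketch: batch-export the NeRF-synthetic Blender scenes to OBJ (Blender < 4.0).
import bpy

for name in ['chair', 'drums', 'ficus', 'hotdog', 'lego', 'materials', 'mic', 'ship']:
    bpy.ops.wm.open_mainfile(filepath=f'{name}.blend')
    bpy.ops.export_scene.obj(filepath=f'{name}.obj')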


LiuShuai086 commented on July 20, 2024

@WB-3 Can you share a link to nerf-syn-gt-mesh? I cannot find it on Google Drive.


LiuShuai086 commented on July 20, 2024

Also, could you please provide the corresponding GT meshes for the mipnerf360 and DTU datasets, if available? Thank you.

