3D Gaussian splatting 07: 代码阅读-训练载入数据和保存结果

目录

训练载入数据

在 train.py 中载入数据对应的方法调用栈如下, 因为convert.py预处理使用的是colmap, 读取数据最终调用的是 readColmapSceneInfo 方法

Scene(dataset, gaussians)
└─sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.depths, args.eval, args.train_test_exp)
  └─readColmapSceneInfo(path, images, depths, eval, train_test_exp, llffhold=8)

读取流程是

  1. 从 images.bin, cameras.bin 读取相机参数和每一帧的位姿
  2. 区分训练集和测试集
  3. 从 points3D.bin 读取3D点云
def read_points3D_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)
    """

    with open(path_to_model_file, "rb") as fid:
        num_points = read_next_bytes(fid, 8, "Q")[0]

        # 创建未初始化的 n * 3 数组, 随机值
        xyzs = np.empty((num_points, 3))
        rgbs = np.empty((num_points, 3))
        errors = np.empty((num_points, 1))

        for p_id in range(num_points):
            binary_point_line_properties = read_next_bytes(
                fid, num_bytes=43, format_char_sequence="QdddBBBd")
            xyz = np.array(binary_point_line_properties[1:4])
            rgb = np.array(binary_point_line_properties[4:7])
            error = np.array(binary_point_line_properties[7])
            track_length = read_next_bytes(
                fid, num_bytes=8, format_char_sequence="Q")[0]
            track_elems = read_next_bytes(
                fid, num_bytes=8*track_length,
                format_char_sequence="ii"*track_length)
            xyzs[p_id] = xyz
            rgbs[p_id] = rgb
            errors[p_id] = error
    return xyzs, rgbs, errors

里面用到的read_next_bytes方法, 读取一段二进制字节, 使用struct.unpack按指定的格式, 转为对应的变量

def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Read and unpack the next bytes from a binary file.
    :param fid:
    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
    :param endian_character: Any of {@, =, <, >, !}
    :return: Tuple of read and unpacked values.
    """
    data = fid.read(num_bytes)
    return struct.unpack(endian_character + format_char_sequence, data)

在 readColmapSceneInfo() 方法中, 如果设置了--eval参数, 会将cam_names 排序后, 按序号与 llffhold 求余是否为0分为训练集和测试集. llffhold 值为8, 所以训练集与测试集的比例为 7:1. 如果没有指定, 则全部数据作为训练集. 如果要手工指定测试集, 可以在 sparse/0 下创建一个 test.txt, 将参数 llffhold 的默认值改为0.

if eval:
    if "360" in path:
        llffhold = 8
    if llffhold:
        print("------------LLFF HOLD-------------")
        cam_names = [cam_extrinsics[cam_id].name for cam_id in cam_extrinsics]
        cam_names = sorted(cam_names)
        test_cam_names_list = [name for idx, name in enumerate(cam_names) if idx % llffhold == 0]
    else:
        with open(os.path.join(path, "sparse/0", "test.txt"), 'r') as file:
            test_cam_names_list = [line.strip() for line in file]

再下面会判断是否有 points3D.ply, 存在就读取, 不存在就创建一个再读取

    ply_path = os.path.join(path, "sparse/0/points3D.ply")
    bin_path = os.path.join(path, "sparse/0/points3D.bin")
    txt_path = os.path.join(path, "sparse/0/points3D.txt")
    if not os.path.exists(ply_path):
        print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
        try:
            xyz, rgb, _ = read_points3D_binary(bin_path)
        except:
            xyz, rgb, _ = read_points3D_text(txt_path)
        storePly(ply_path, xyz, rgb)
    try:
        pcd = fetchPly(ply_path)

读取出来的是 BasicPointCloud 类型的数据

BasicPointCloud

BasicPointCloud 用于表示三维点云的基础数据结构, 包含坐标、颜色和法线信息

class BasicPointCloud(NamedTuple):
    points : np.array
    colors : np.array
    normals : np.array

def geom_transform_points(points, transf_matrix):
    # 将点转换为齐次坐标后应用变换矩阵, 返回经过投影变换后的三维坐标, PyTorch实现的齐次坐标变换,支持批量变换操作
    P, _ = points.shape
    ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
    points_hom = torch.cat([points, ones], dim=1)
    points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))

    denom = points_out[..., 3:] + 0.0000001
    return (points_out[..., :3] / denom).squeeze(dim=0)

def getWorld2View(R, t):
    # 创建世界坐标系到相机坐标系的4x4变换矩阵 R: 3x3旋转矩阵,t: 3D平移向量
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0
    return np.float32(Rt)

def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
    # 增强版视图变换,支持场景平移和缩放, 通过相机到世界坐标系的逆变换实现
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    C2W = np.linalg.inv(Rt)
    cam_center = C2W[:3, 3]
    cam_center = (cam_center + translate) * scale
    C2W[:3, 3] = cam_center
    Rt = np.linalg.inv(C2W)
    return np.float32(Rt)

def getProjectionMatrix(znear, zfar, fovX, fovY):
    # 生成透视投影矩阵 参数包含近/远裁剪面,水平和垂直视场角 返回4x4投影矩阵
    tanHalfFovY = math.tan((fovY / 2))
    tanHalfFovX = math.tan((fovX / 2))

    top = tanHalfFovY * znear
    bottom = -top
    right = tanHalfFovX * znear
    left = -right

    P = torch.zeros(4, 4)

    z_sign = 1.0

    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P

def fov2focal(fov, pixels):
    # 视场角转焦距(单位:像素)
    return pixels / (2 * math.tan(fov / 2))

def focal2fov(focal, pixels):
    # 焦距转视场角
    return 2*math.atan(pixels/(2*focal))

训练结果数据结构

安装 pyntcloud

pip install pyntcloud

查看 ply 文件

>>> from pyntcloud import PyntCloud
>>> cloud = PyntCloud.from_file("output/1ed8e6a1-9/point_cloud/iteration_7000/point_cloud.ply")
>>> print(cloud)
PyntCloud
743269 points with 59 scalar fields
0 faces in mesh
0 kdtrees
0 voxelgrids
Centroid: 1.6537141799926758, -2.9306182861328125, -4.471662521362305

>>> type(cloud.points)
<class 'pandas.core.frame.DataFrame'>

点的数据类型是 DataFrame, 查看第一个点的属性列, 每一项都是float32/4个字节, 但是属性太多被省略了

>>> print(cloud.points.loc[0])
x          1.947371
y         -0.500535
z          1.388533
nx         0.000000
ny         0.000000
             ...   
scale_2   -4.380099
rot_0      0.840099
rot_1     -0.143527
rot_2      0.065419
rot_3      0.179504
Name: 0, Length: 62, dtype: float32

此去掉rows限制, 就可以打印全貌了

>>> pd.set_option('display.max_rows', None)
>>> print(cloud.points.loc[0])
x            1.947371
y           -0.500535
z            1.388533
nx           0.000000
ny           0.000000
nz           0.000000
f_dc_0      -0.264158
f_dc_1       0.352959
f_dc_2       0.361867
f_rest_0     0.012889
f_rest_1    -0.001385
f_rest_2     0.044487
f_rest_3     0.013909
# 省略 f_rest_ 开头的字段
f_rest_41   -0.038870
f_rest_42   -0.015730
f_rest_43    0.042109
f_rest_44    0.021378
opacity     -1.817663
scale_0     -5.108221
scale_1     -4.811676
scale_2     -4.380099
rot_0        0.840099
rot_1       -0.143527
rot_2        0.065419
rot_3        0.179504
Name: 0, dtype: float32

结果数据输出的时候是通过拼接参数产生的

attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)

对应的属性含义

  • x, y, z: 3D点云位置坐标
  • nx, ny, nz: 未使用
  • f_dc_0 - f_dc_2, f_rest_0 - f_rest_44: 颜色特征的DC分量和剩余分量, 3阶一共16个RGB球谐系数
  • opacity: 不透明度参数
  • scale_0 - scale_2: 缩放参数
  • rot_0 - rot_3: 旋转参数
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值