目录
- 3D Gaussian splatting 01: 环境搭建
- 3D Gaussian splatting 02: 快速评估
- 3D Gaussian splatting 03: 用户数据训练和结果查看
- 3D Gaussian splatting 04: 代码阅读-提取相机位姿和稀疏点云
- 3D Gaussian splatting 05: 代码阅读-训练整体流程
- 3D Gaussian splatting 06: 代码阅读-训练参数
- 3D Gaussian splatting 07: 代码阅读-训练载入数据和保存结果
- 3D Gaussian splatting 08: 代码阅读-渲染
训练载入数据
在 train.py 中, 载入数据的方法调用栈如下. 由于 convert.py 预处理使用的是 COLMAP, 读取数据最终调用的是 readColmapSceneInfo 方法:
Scene(dataset, gaussians)
└─sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.depths, args.eval, args.train_test_exp)
└─readColmapSceneInfo(path, images, depths, eval, train_test_exp, llffhold=8)
读取流程是
- 从 images.bin, cameras.bin 读取相机参数和每一帧的位姿
- 区分训练集和测试集
- 从 points3D.bin 读取3D点云
def read_points3D_binary(path_to_model_file):
    """Read a COLMAP ``points3D.bin`` file into NumPy arrays.

    Mirrors COLMAP's own reader/writer, see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)

    File layout as parsed here: a uint64 point count, then for every point a
    43-byte record ``QdddBBBd`` (id, x, y, z, r, g, b, error), a uint64 track
    length, and ``track length`` pairs of int32 (presumably image id / 2D point
    index — not returned).

    :param path_to_model_file: path to ``points3D.bin``.
    :return: ``(xyzs, rgbs, errors)`` as float64 arrays of shape
        (N, 3), (N, 3) and (N, 1).
    """
    with open(path_to_model_file, "rb") as fid:
        (count,) = read_next_bytes(fid, 8, "Q")
        # np.empty leaves the buffers uninitialised; every row is filled in below.
        positions = np.empty((count, 3))
        colors = np.empty((count, 3))
        reproj_errors = np.empty((count, 1))
        for idx in range(count):
            record = read_next_bytes(
                fid, num_bytes=43, format_char_sequence="QdddBBBd")
            positions[idx] = record[1:4]
            colors[idx] = record[4:7]
            reproj_errors[idx] = record[7]
            (track_len,) = read_next_bytes(
                fid, num_bytes=8, format_char_sequence="Q")
            # The track entries are discarded, but they must still be consumed
            # so the file cursor lands on the next point record.
            read_next_bytes(
                fid, num_bytes=8 * track_len,
                format_char_sequence="ii" * track_len)
    return positions, colors, reproj_errors
其中用到的 read_next_bytes 方法会读取一段二进制字节, 并使用 struct.unpack 按指定的格式解析为对应的变量.
def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Read ``num_bytes`` from a binary file object and unpack them with ``struct``.

    :param fid: file object opened in binary mode.
    :param num_bytes: number of bytes to read; must equal
        ``struct.calcsize(endian_character + format_char_sequence)``,
        otherwise ``struct.unpack`` raises ``struct.error``.
    :param format_char_sequence: ``struct`` format codes, e.g. ``"Q"`` or
        ``"QdddBBBd"``.
    :param endian_character: ``struct`` byte-order prefix, one of
        ``@ = < > !``; defaults to little-endian ``<``.
    :return: tuple of the unpacked values.
    """
    chunk = fid.read(num_bytes)
    fmt = endian_character + format_char_sequence
    return struct.unpack(fmt, chunk)
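在 readColmapSceneInfo() 方法中, 如果设置了 --eval 参数, 会先将 cam_names 排序, 再用每个相机的序号对 llffhold 求余: 余数为 0 的归入测试集, 其余归入训练集. llffhold 默认值为 8, 所以训练集与测试集的比例约为 7:1. 如果没有设置 --eval, 则全部数据都作为训练集. 如果要手工指定测试集, 可以在 sparse/0 下创建一个 test.txt (每行一个图片文件名), 并将参数 llffhold 的默认值改为 0. 注意: 当路径中包含 "360" 时, 代码会强制把 llffhold 设为 8, 此时 test.txt 不会生效.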
# Split cameras into train/test sets; `eval` carries the --eval flag.
if eval:
    # A path containing "360" always forces the every-8th-image hold-out,
    # overriding any llffhold value (including 0).
    if "360" in path:
        llffhold = 8
    if llffhold:
        # LLFF-style hold-out: after sorting by image name, the cameras at
        # indices 0, llffhold, 2*llffhold, ... go to the test set.
        print("------------LLFF HOLD-------------")
        cam_names = [cam_extrinsics[cam_id].name for cam_id in cam_extrinsics]
        cam_names = sorted(cam_names)
        test_cam_names_list = [name for idx, name in enumerate(cam_names) if idx % llffhold == 0]
    else:
        # llffhold == 0: take the test image names from sparse/0/test.txt,
        # one name per line.
        with open(os.path.join(path, "sparse/0", "test.txt"), 'r') as file:
            test_cam_names_list = [line.strip() for line in file]
else:
    # Without --eval every camera is a training camera. Binding an empty
    # list here keeps test_cam_names_list defined on every path.
    test_cam_names_list = []
再往下会判断 sparse/0 下是否存在 points3D.ply: 存在就直接读取; 不存在则先从 points3D.bin 转换生成 (读取失败时回退到 points3D.txt), 然后再读取.
# Cache COLMAP's sparse point cloud as sparse/0/points3D.ply; the conversion
# only runs when that file does not exist yet.
ply_path = os.path.join(path, "sparse/0/points3D.ply")
bin_path = os.path.join(path, "sparse/0/points3D.bin")
txt_path = os.path.join(path, "sparse/0/points3D.txt")
if not os.path.exists(ply_path):
    print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
    try:
        xyz, rgb, _ = read_points3D_binary(bin_path)
    except Exception:
        # Binary model missing or unreadable: fall back to the text export.
        # `except Exception` (rather than a bare except) keeps the deliberate
        # fallback but no longer swallows KeyboardInterrupt / SystemExit.
        xyz, rgb, _ = read_points3D_text(txt_path)
    storePly(ply_path, xyz, rgb)
# Load the PLY point cloud (possibly just generated above).
# NOTE(review): the matching `except` clause is omitted from this excerpt.
try:
    pcd = fetchPly(ply_path)
读取出来的是 BasicPointCloud 类型的数据
BasicPointCloud
BasicPointCloud 用于表示三维点云的基础数据结构, 包含坐标、颜色和法线信息
class BasicPointCloud(NamedTuple):
    """Minimal point-cloud container holding positions, colors and normals.

    Each field is a NumPy array. Presumably one row per point, (N, 3) each;
    TODO confirm against fetchPly.
    """
    # `np.ndarray` is the array type; the previous `np.array` annotation named
    # the constructor function, which type checkers reject as a type.
    points: np.ndarray
    colors: np.ndarray
    normals: np.ndarray
def geom_transform_points(points, transf_matrix):
    """Apply a 4x4 transform to 3D points, followed by the homogeneous divide.

    Points are treated as row vectors: ``[x, y, z, 1] @ transf_matrix``. The
    result is divided by its w component, with 1e-7 added to w so that w == 0
    does not cause a division by zero.

    :param points: tensor of shape (P, 3).
    :param transf_matrix: tensor of shape (4, 4).
    :return: transformed points of shape (P, 3).
    """
    num_points, _ = points.shape
    # new_ones matches points' dtype and device.
    homogeneous = torch.cat((points, points.new_ones((num_points, 1))), dim=1)
    # The unsqueeze gives a (1, P, 4) result; the final squeeze(dim=0) removes
    # that leading dim again, so P == 1 still yields shape (1, 3).
    transformed = torch.matmul(homogeneous, transf_matrix.unsqueeze(0))
    w = transformed[..., 3:] + 1e-7
    return (transformed[..., :3] / w).squeeze(dim=0)
def getWorld2View(R, t):
    """Assemble a 4x4 world-to-view matrix as float32.

    The upper-left 3x3 block is ``R`` transposed, the last column holds ``t``,
    and the bottom row is ``[0, 0, 0, 1]``.

    :param R: 3x3 rotation matrix (ndarray).
    :param t: length-3 translation vector.
    """
    view = np.eye(4)
    view[:3, :3] = R.T
    view[:3, 3] = t
    return view.astype(np.float32)
def getWorld2View2(R, t, translate=None, scale=1.0):
    """Build a float32 world-to-view matrix after shifting and scaling the camera.

    The camera center C is recovered by inverting ``[R^T | t]``, replaced by
    ``(C + translate) * scale``, and the result is inverted back into a
    world-to-view matrix.

    :param R: 3x3 rotation matrix (ndarray).
    :param t: length-3 translation vector.
    :param translate: length-3 offset added to the camera center; ``None``
        (the default) means no offset. A ``None`` sentinel replaces the former
        shared ``np.array`` default argument. Callers passing an array are
        unaffected.
    :param scale: factor applied to the shifted camera center.
    :return: 4x4 float32 ndarray.
    """
    if translate is None:
        translate = np.zeros(3)
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    # Move to camera-to-world space, where the last column is the camera center.
    C2W = np.linalg.inv(Rt)
    C2W[:3, 3] = (C2W[:3, 3] + translate) * scale
    Rt = np.linalg.inv(C2W)
    return np.float32(Rt)
def getProjectionMatrix(znear, zfar, fovX, fovY):
    """Build a 4x4 perspective projection matrix as a torch tensor.

    :param znear: near clipping distance.
    :param zfar: far clipping distance.
    :param fovX: horizontal field of view, in radians.
    :param fovY: vertical field of view, in radians.
    :return: 4x4 tensor (default torch dtype). With ``P[3, 2] = 1`` the
        homogeneous w equals view-space z, and after the divide z = znear
        maps to depth 0 and z = zfar to depth 1.
    """
    # Half-extents of the view frustum on the near plane.
    top = math.tan(fovY / 2) * znear
    right = math.tan(fovX / 2) * znear
    bottom, left = -top, -right

    z_sign = 1.0
    depth_range = zfar - znear

    proj = torch.zeros(4, 4)
    proj[0, 0] = 2.0 * znear / (right - left)
    proj[0, 2] = (right + left) / (right - left)
    proj[1, 1] = 2.0 * znear / (top - bottom)
    proj[1, 2] = (top + bottom) / (top - bottom)
    proj[2, 2] = z_sign * zfar / depth_range
    proj[2, 3] = -(zfar * znear) / depth_range
    proj[3, 2] = z_sign
    return proj
def fov2focal(fov, pixels):
    """Convert a field of view (radians) into a focal length in pixels.

    focal = pixels / (2 * tan(fov / 2))
    """
    half_fov_tan = math.tan(fov / 2)
    return pixels / (2 * half_fov_tan)
def focal2fov(focal, pixels):
    """Convert a focal length in pixels into a field of view in radians.

    fov = 2 * atan(pixels / (2 * focal)); this is the inverse of fov2focal.
    """
    half_extent_ratio = pixels / (2 * focal)
    return 2 * math.atan(half_extent_ratio)
训练结果数据结构
安装 pyntcloud
pip install pyntcloud
查看 ply 文件
>>> from pyntcloud import PyntCloud
>>> cloud = PyntCloud.from_file("output/1ed8e6a1-9/point_cloud/iteration_7000/point_cloud.ply")
>>> print(cloud)
PyntCloud
743269 points with 59 scalar fields
0 faces in mesh
0 kdtrees
0 voxelgrids
Centroid: 1.6537141799926758, -2.9306182861328125, -4.471662521362305
>>> type(cloud.points)
<class 'pandas.core.frame.DataFrame'>
cloud.points 是一个 pandas DataFrame. 查看第一个点的各属性列, 每一项都是 float32 (4 个字节), 但因为属性太多, 中间部分被省略了:
>>> print(cloud.points.loc[0])
x 1.947371
y -0.500535
z 1.388533
nx 0.000000
ny 0.000000
...
scale_2 -4.380099
rot_0 0.840099
rot_1 -0.143527
rot_2 0.065419
rot_3 0.179504
Name: 0, Length: 62, dtype: float32
取消 pandas 的最大显示行数限制, 就可以打印全部字段了 (需要先执行 import pandas as pd):
>>> pd.set_option('display.max_rows', None)
>>> print(cloud.points.loc[0])
x 1.947371
y -0.500535
z 1.388533
nx 0.000000
ny 0.000000
nz 0.000000
f_dc_0 -0.264158
f_dc_1 0.352959
f_dc_2 0.361867
f_rest_0 0.012889
f_rest_1 -0.001385
f_rest_2 0.044487
f_rest_3 0.013909
# 省略 f_rest_ 开头的字段
f_rest_41 -0.038870
f_rest_42 -0.015730
f_rest_43 0.042109
f_rest_44 0.021378
opacity -1.817663
scale_0 -5.108221
scale_1 -4.811676
scale_2 -4.380099
rot_0 0.840099
rot_1 -0.143527
rot_2 0.065419
rot_3 0.179504
Name: 0, dtype: float32
保存结果时, 每个点的数据是把各项参数按列拼接得到的:
# Concatenate the per-point parameters column-wise (axis=1) in the order they
# appear in the saved PLY: position, normals, SH DC, SH rest, opacity, scale,
# rotation.
attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
对应的属性含义
- x, y, z: 3D点云位置坐标
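- nx, ny, nz: 法线, 未使用 (保存为 0)
- f_dc_0 - f_dc_2, f_rest_0 - f_rest_44: 颜色的球谐 (SH) 系数, 分为 DC 分量 (0 阶, 3 个) 和其余分量 (1~3 阶, 45 个). 3 阶球谐每个颜色通道有 16 个系数, RGB 三个通道共 48 个
- opacity: 不透明度, 保存的是激活前的值, 经 sigmoid 后才得到 [0, 1] 范围内的不透明度
- scale_0 - scale_2: 三个轴上的缩放, 保存的是对数值, 经 exp 后才得到实际尺度
- rot_0 - rot_3: 旋转, 以四元数表示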