
The K-Nearest Neighbor (kNN) algorithm is used to initialize the x, y, z scales of the 3D Gaussians

  • For each point, the distances to its 3 nearest neighbor points are averaged, and that value is used to initialize the scale of the 3D Gaussian isotropically (the same value on x, y, and z).
  • The kNN used here is `distCUDA2` from the simple_knn library, which is implemented in spatial.cu on top of `SimpleKNN::knn`.

```cuda
#include "spatial.h"
#include "simple_knn.h"

torch::Tensor distCUDA2(const torch::Tensor& points)
{
  const int P = points.size(0);

  auto float_opts = points.options().dtype(torch::kFloat32);
  torch::Tensor means = torch::full({P}, 0.0, float_opts);

  SimpleKNN::knn(P, (float3*)points.contiguous().data<float>(), means.contiguous().data<float>());

  return means;
}
```

- `SimpleKNN::knn` is defined in `simple_knn.cu`, and the `K` nearest neighbors are found through calls to `updateKBest<K>`.
- The call is hardcoded as `updateKBest<3>`, which confirms that **3 nearest neighbors are used**.
- As a result, `dist2` in the `create_from_pcd` code below is the value `distCUDA2` returns for each point: the mean distance to its 3 nearest neighbor points, in squared form, which is why the code applies `torch.sqrt` to it before using it as a scale.
- `dist2` is then used to initialize the scale of each 3D Gaussian isotropically.
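
To make the initialization concrete, here is a minimal NumPy sketch of what these steps amount to. It is a brute-force CPU stand-in, not the simple_knn CUDA code: `mean_sq_dist_to_3nn` and `init_log_scales` are hypothetical names, and the sketch assumes that a point is not counted as its own neighbor and that the returned value is a mean of squared distances (consistent with the `torch.sqrt` applied to `dist2` in `create_from_pcd`).

```python
import numpy as np

def mean_sq_dist_to_3nn(points: np.ndarray) -> np.ndarray:
    """Brute-force stand-in for distCUDA2 (assumption: mean squared distance to 3 nearest neighbors)."""
    diff = points[:, None, :] - points[None, :, :]   # (N, N, 3) pairwise differences
    sq = np.sum(diff * diff, axis=-1)                # (N, N) squared distances
    np.fill_diagonal(sq, np.inf)                     # assumption: a point is not its own neighbor
    nearest3 = np.sort(sq, axis=1)[:, :3]            # (N, 3) three smallest squared distances
    return nearest3.mean(axis=1)                     # (N,)

def init_log_scales(points: np.ndarray) -> np.ndarray:
    """Isotropic log-scale per point, mirroring the clamp, sqrt, log, and repeat steps."""
    dist2 = np.maximum(mean_sq_dist_to_3nn(points), 1e-7)          # same role as torch.clamp_min
    return np.repeat(np.log(np.sqrt(dist2))[:, None], 3, axis=1)   # (N, 3), identical columns

points = np.array([[0, 0, 0], [1, 0, 0], [0, 2, 0], [0, 0, 3], [5, 5, 5]], dtype=np.float32)
scales = init_log_scales(points)
print(scales.shape)  # (5, 3)
```

Because the three columns are identical, every Gaussian starts out as a sphere whose size follows the local point density. The O(N²) pairwise matrix is only practical for small point clouds, which is presumably why the real code uses a CUDA kernel.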

![image](https://github.com/sandokim/sandokim.github.io/assets/74639652/031fa214-d612-487f-956c-bf2923c6695b)

```python
# 3dgs/scene/gaussian_model.py

from simple_knn._C import distCUDA2

class GaussianModel:

...

    def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
        self.spatial_lr_scale = spatial_lr_scale
        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
        features[:, :3, 0 ] = fused_color
        features[:, 3:, 1:] = 0.0

        print("Number of points at initialisation : ", fused_point_cloud.shape[0])

        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
        scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
        rots[:, 0] = 1

        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))

        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
        self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
        self._scaling = nn.Parameter(scales.requires_grad_(True))
        self._rotation = nn.Parameter(rots.requires_grad_(True))
        self._opacity = nn.Parameter(opacities.requires_grad_(True))
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
```

For each of the n_points points in the point cloud, SH coefficients are defined and initialized for each RGB channel.

  • fused_color is the point cloud's color information converted to Spherical Harmonics (SH). In 3DGS, SH coefficients represent color as a function of viewing direction, which captures view-dependent effects such as specular lighting.
  • The conversion is done by this line: `fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())`
  • Here RGB2SH converts RGB color values into SH coefficients, specifically the 0th (DC) coefficient of each channel. fused_color itself has shape (num_points, 3), which is why it can be assigned to features[:, :3, 0]. The features tensor that holds all the coefficients has the following structure:
    • features.shape: (num_points, 3, (max_sh_degree + 1) ** 2) # (n_points, RGB, number of SH coefficients)
      • num_points: number of points
      • 3: RGB color channels (Red, Green, Blue)
      • (max_sh_degree + 1) ** 2: number of SH coefficients
  • SH coefficients are ordered from low degree to high degree, and each color channel has several coefficients.
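
As a rough illustration of the conversion, the sketch below assumes that RGB2SH is an offset by 0.5 followed by a division by the value of the degree-0 SH basis function, which is my understanding of the 3DGS helper. The names `rgb_to_sh_dc` and `sh_dc_to_rgb` are hypothetical and used only here.

```python
import numpy as np

# Value of the degree-0 SH basis function, 1 / (2 * sqrt(pi)).
C0 = 0.28209479177387814

def rgb_to_sh_dc(rgb: np.ndarray) -> np.ndarray:
    """Map RGB in [0, 1] to a degree-0 (DC) SH coefficient per channel (assumed formula)."""
    return (rgb - 0.5) / C0

def sh_dc_to_rgb(sh: np.ndarray) -> np.ndarray:
    """Inverse mapping: recover RGB from the DC coefficient."""
    return sh * C0 + 0.5

colors = np.array([[0.5, 0.5, 0.5], [1.0, 0.0, 0.25]])  # (n_points, 3)
sh_dc = rgb_to_sh_dc(colors)                             # still (n_points, 3)
print(sh_dc.shape)
```

The output keeps the (n_points, 3) shape: one DC coefficient per color channel. Mid-gray (0.5) maps to a coefficient of 0 under this formula.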

features_dc, features_rest

  • features_dc is the tensor that stores each point's color information converted to Spherical Harmonics coefficients.
  • At initialization, features_dc contains only the DC (constant, degree-0) coefficient of the RGB color; the remaining SH coefficients are stored in features_rest.
  • features is initialized to zeros with shape (fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2), i.e. (n_points, RGB 3 channels, number of SH coefficients).
  • features[:, :3, 0 ] = fused_color writes the RGB2SH output into the 0th SH coefficient.
  • features[:, 3:, 1:] = 0.0 has effectively no effect, because the index 3: goes past the 3 RGB channels and selects an empty slice. The trailing 1: refers to the SH coefficients from index 1 onward, so the line appears to simply restate that their initial value is 0.0, even though features was already initialized with zeros.
```python
# 3dgs/scene/gaussian_model.py

...

    def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
        self.spatial_lr_scale = spatial_lr_scale
        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
        features[:, :3, 0 ] = fused_color # (n_points, RGB 3 channels, 0th SH coefficient)
        features[:, 3:, 1:] = 0.0 # (n_points, past the index range of the RGB 3 channels, SH coefficients from index 1 onward)

...
```
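
The empty-slice behavior is easy to check on a small array. The sketch below uses NumPy in place of torch (basic slicing behaves the same way here) with placeholder color values, and it also reproduces the features_dc / features_rest split that `create_from_pcd` performs with `transpose(1, 2)`. The sizes are assumptions chosen for the example.

```python
import numpy as np

n_points, max_sh_degree = 4, 3            # assumed sizes for illustration
n_coeffs = (max_sh_degree + 1) ** 2       # 16 SH coefficients per channel

features = np.zeros((n_points, 3, n_coeffs), dtype=np.float32)
fused_color = np.full((n_points, 3), 0.7, dtype=np.float32)  # placeholder DC values
features[:, :3, 0] = fused_color          # fill the 0th SH coefficient of each channel

# The channel axis has only indices 0..2, so 3: selects nothing.
empty_slice_shape = features[:, 3:, 1:].shape    # (4, 0, 15)
before = features.copy()
features[:, 3:, 1:] = 0.0                        # assignment to an empty slice
unchanged = np.array_equal(before, features)     # True: nothing was written

# Same split as in create_from_pcd; torch's transpose(1, 2) becomes transpose(0, 2, 1).
features_dc = np.ascontiguousarray(features[:, :, 0:1].transpose(0, 2, 1))   # (4, 1, 3)
features_rest = np.ascontiguousarray(features[:, :, 1:].transpose(0, 2, 1))  # (4, 15, 3)
print(empty_slice_shape, unchanged, features_dc.shape, features_rest.shape)
```

If the intent had been to zero the higher-order coefficients of the three channels, the slice would have been `features[:, :3, 1:]`. Since features already starts as zeros, the result is the same either way.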

In features_dc, the 0th SH coefficients of the RGB 3 channels are f_dc_0, f_dc_1, and f_dc_2.

  • Looking at load_ply, features_dc for the 0th SH coefficient is allocated over the RGB 3 channels as (n_points, RGB 3 channels, number of degree-0 SH coefficients), i.e. (n_points, 3, 1).
  • The RGB 3 channels are then indexed as 0, 1, 2, and the 0th SH coefficient values are read from "f_dc_0", "f_dc_1", and "f_dc_2" respectively.
  • To summarize:
    • "f_dc_0" is the 0th SH coefficient value for R
    • "f_dc_1" is the 0th SH coefficient value for G
    • "f_dc_2" is the 0th SH coefficient value for B
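
The same mapping can be sketched without a PLY file. In the example below, a NumPy structured array stands in for `plydata.elements[0]` (a hypothetical stand-in with made-up values, not the plyfile API), and the three named fields are copied into the channel axis of a (n_points, 3, 1) array.

```python
import numpy as np

n_points = 3
# Hypothetical stand-in for plydata.elements[0]: one named field per PLY property.
vertex = np.zeros(n_points, dtype=[("f_dc_0", "f4"), ("f_dc_1", "f4"), ("f_dc_2", "f4")])
vertex["f_dc_0"] = [0.1, 0.2, 0.3]   # DC coefficient for R
vertex["f_dc_1"] = [0.4, 0.5, 0.6]   # DC coefficient for G
vertex["f_dc_2"] = [0.7, 0.8, 0.9]   # DC coefficient for B

features_dc = np.zeros((n_points, 3, 1))   # (n_points, RGB, 1 degree-0 coefficient)
for channel in range(3):
    # Channel index 0, 1, 2 corresponds to R, G, B.
    features_dc[:, channel, 0] = vertex[f"f_dc_{channel}"]

print(features_dc[1, :, 0])   # DC coefficients (R, G, B) of the second point
```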

```python
# 3dgs/scene/gaussian_model.py

...

    def load_ply(self, path):
        plydata = PlyData.read(path)

        xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
                        np.asarray(plydata.elements[0]["y"]),
                        np.asarray(plydata.elements[0]["z"])),  axis=1)
        opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]

        features_dc = np.zeros((xyz.shape[0], 3, 1)) # (n_points, RGB 3 channels, (0 + 1) ** 2) = (n_points, 3, 1)
        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) # (n_points, R, 0th SH coefficient for R)
        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) # (n_points, G, 0th SH coefficient for G)
        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) # (n_points, B, 0th SH coefficient for B)

...
```
