[3D CV 연구] 3DGS SH2RGB
Spherical Harmonics (SH)의 DC 성분
- Spherical Harmonics(SH)์ ๊ตฌ๋ฉด ์ขํ๊ณ์์ ์ ์๋๋ ํจ์์ ์งํฉ์ผ๋ก, 3D ๊ทธ๋ํฝ์ค์์ ์กฐ๋ช , ๋ฐ์ฌ, ์ํฅ ๋ฑ์ ํํํ๋๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
 - SH ๊ณ์๋ SH Level์ ๋ฐ๋ผ DC ์ฑ๋ถ๊ณผ ๋๋จธ์ง ์ฑ๋ถ์ผ๋ก ๋๋ฉ๋๋ค.
 - ์๋ ์ฝ๋์์ ์ค์ฌ์ฉ ์์๋ฅผ ๋จผ์  ์ดํดํด๋ด ์๋ค.
 
features = concat([features_dc, features_rest], dim=1)  # (n_points, n_sh_coeffs, RGB): DC and rest stacked on the SH-coefficient axis
# 3dgs/scene/gaussian_model.py
class GaussianModel:
    # ... (other members omitted in this excerpt)
    ...

    @property
    def get_features(self):
        """All SH coefficients of every Gaussian, DC term first.

        ``_features_dc`` (one DC coefficient per RGB channel) and
        ``_features_rest`` (the higher-order coefficients) are concatenated
        along dim=1, the SH-coefficient axis. With the 3DGS layout
        (n_points, n_coeffs, RGB) this is e.g. (n, 1, 3) + (n, 15, 3)
        -> (n, 16, 3) for SH degree 3.
        """
        features_dc = self._features_dc
        features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)
features_dc # (n_points, 1, RGB)features_rest # (n_points, 15, RGB)featrues_dc์features_rest๋ฅผ sh ๊ณ์ ์ฐจ์์ผ๋ก concat ํฉ๋๋ค.
RGB 3 channel๋ง๋ค n๋ฒ์งธ sh ๊ณ์์ ๋ํ ์ฑ๋์ด ์กด์ฌํฉ๋๋ค.
features_dc[:, 0, 0] = n_points, R channel, sh 0๋ฒ์งธ ๊ณ์,"f_dc_0"์ผ๋ก ๋ณ์์ด๋ฆ ์ ์features_dc[:, 0, 1] = n_points, G channel, sh 0๋ฒ์งธ ๊ณ์,"f_dc_1"์ผ๋ก ๋ณ์์ด๋ฆ ์ ์features_dc[:, 0, 2] = n_points, B channel, sh 0๋ฒ์งธ ๊ณ์,"f_dc_2"์ผ๋ก ๋ณ์์ด๋ฆ ์ ์features_dc์features_extra๋ฅผ ๋ถ๋ฌ์ฌ ๋๋ shape์ด(n_points, RGB, sh ๊ณ์)๋ก ์ ์๋ฉ๋๋ค.- ํ์ง๋ง ํ์ต์์ ์ฌ์ฉ๋  ๋, 
features_dc,features_extra๋ชจ๋transpose(1, 2)ํ์ฌself_features_dc,self._features_rest๋ก ์ ์ํฉ๋๋ค. self_features_dc,self._features_rest์ shape์(n_points, sh ๊ณ์, RGB)๋ก ๋ฐ๋๋๋ค.
# 3dgs/scene/gaussian_model.py
class GaussianModel:
    # ... (other members omitted in this excerpt)
    ...

    def load_ply(self, path):
        """Load Gaussian parameters from the 3DGS ``.ply`` file at ``path``.

        SH coefficients are stored per vertex, channel-major: ``f_dc_0..2`` are
        the DC terms of R, G, B and ``f_rest_*`` hold the remaining
        coefficients of R, then G, then B. They are read into
        (n_points, RGB, n_coeffs) arrays and transposed to the
        (n_points, n_coeffs, RGB) layout used during training.
        """
        ...  # excerpt: reading `plydata` and `xyz` is omitted here
        # (n_points, RGB, 1): one DC coefficient per color channel.
        features_dc = np.zeros((xyz.shape[0], 3, 1))
        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
        extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
        # Sort by the numeric suffix so "f_rest_10" follows "f_rest_9" (a plain string sort would not).
        extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
        # 3 channels x ((max_sh_degree + 1) ** 2 - 1) non-DC coefficients each.
        assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
        for idx, attr_name in enumerate(extra_f_names):
            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
        # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
        features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
        ...  # excerpt: code between these lines is omitted
        # transpose(1, 2): (n, RGB, coeffs) -> (n, coeffs, RGB), the training layout.
        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
        ...  # excerpt: the rest of load_ply is omitted
(n_points, sh ๊ณ์, RGB)๋ก ์ฐจ์์ด ๋ฐ๋self._features_dc,self._features_rest๋sh ๊ณ์์ฐจ์์ธdim=1์์concatํฉ๋๋ค.
# 3dgs/scene/gaussian_model.py
class GaussianModel:
    # ... (other members omitted in this excerpt)
    ...

    @property
    def get_features(self):
        """All SH coefficients of every Gaussian, DC term first.

        ``_features_dc`` (one DC coefficient per RGB channel) and
        ``_features_rest`` (the higher-order coefficients) are concatenated
        along dim=1, the SH-coefficient axis. With the 3DGS layout
        (n_points, n_coeffs, RGB) this is e.g. (n, 1, 3) + (n, 15, 3)
        -> (n, 16, 3) for SH degree 3.
        """
        features_dc = self._features_dc
        features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)
f_dc, self._features_dc, f_rest, self._features_rest๋ (n_points, sh ๊ณ์, RGB) ํํ๋ก ํ์ต์ ์ฌ์ฉ๋ฉ๋๋ค.
self._features_dc,self._features_rest๋ฅผ ์ ์ํ๋ ๋ถ๋ถ์ ๋ด ์๋ค.- 
    
random_initilization์ผ๋ก pcd๋ฅผ 100,000๊ฐ ์ ์ํ์์ ๋,features.shape # (n_points, RGB, sh ๊ณ์) = (100000, 3, 16) - 
    
self._features_dc๋features[:,:,0:1].transpose(1, 2) # (n_points, RGB, sh 0๋ฒ์งธ ๊ณ์).transpose(1, 2) -> (n_points, sh 0๋ฒ์งธ ๊ณ์, RGB) = (100000, 1, 3) - 
    
self._features_rest๋features[:,:,1:].transpose(1, 2) # (n_points, RGB, sh 1~15๋ฒ์งธ ๊ณ์).transpose(1, 2) -> (n_points, sh 1~15๋ฒ์งธ ๊ณ์, RGB) = (100000, 15, 3) 
# 3dgs/scene/gaussian_model.py
class GaussianModel:
    # ... (other members omitted in this excerpt)
    ...

    def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
        """Initialise the Gaussians from the point cloud ``pcd``.

        Each point's color is converted to the DC SH coefficient with RGB2SH;
        all higher-order (rest) coefficients start at zero. The coefficients
        are stored as ``_features_dc`` (n, 1, 3) and ``_features_rest``
        (n, (max_sh_degree + 1) ** 2 - 1, 3), i.e. (n_points, n_coeffs, RGB).
        """
        self.spatial_lr_scale = spatial_lr_scale
        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
        # (n_points, RGB, (max_sh_degree + 1) ** 2), zero-initialised.
        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
        # Coefficient 0 (DC) of every channel comes from the point color.
        features[:, :3, 0 ] = fused_color
        # All non-DC coefficients start at zero (already true after zeros(); kept
        # explicit). dim 1 has exactly 3 channels, so it must be sliced with `:`;
        # a `3:` slice would select nothing and silently do nothing.
        features[:, :, 1:] = 0.0
        print("Number of points at initialisation : ", fused_point_cloud.shape[0])
        # presumably dist2 is a squared neighbour distance per point; clamped to stay > 0 before log.
        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
        scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
        # Identity rotation: component 0 set to 1 (assumes a w-first quaternion layout).
        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
        rots[:, 0] = 1
        # Stored pre-activation; presumably sigmoid(opacity) == 0.1 at start.
        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
        # transpose(1, 2): (n, RGB, coeffs) -> (n, coeffs, RGB); DC keeps only coefficient 0.
        self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
        ...  # excerpt: the rest of create_from_pcd is omitted
์ต์ข
์ ์ผ๋ก self._features_dc๊ณผ self._features_rest๋ sh ๊ณ์ ์ฑ๋์์ concatํ์ฌ ํ์ต์ ์ฌ์ฉ๋ฉ๋๋ค.
class GaussianModel:
    # ... (other members omitted in this excerpt)
    ...

    @property
    def get_features(self):
        """All SH coefficients of every Gaussian, DC term first.

        ``_features_dc`` (one DC coefficient per RGB channel) and
        ``_features_rest`` (the higher-order coefficients) are concatenated
        along dim=1, the SH-coefficient axis. With the 3DGS layout
        (n_points, n_coeffs, RGB) this is e.g. (n, 1, 3) + (n, 15, 3)
        -> (n, 16, 3) for SH degree 3.
        """
        features_dc = self._features_dc
        features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)
- features_dc.shape # (n_points, sh 0번째 계수, RGB) = (100000, 1, 3)
- features_rest.shape # (n_points, sh 1~15번째 계수, RGB) = (100000, 15, 3)
- torch.cat((features_dc, features_rest), dim=1).shape # (n_points, sh 0~15번째 계수, RGB) = (100000, 16, 3)
ํ์ต๋ self._features_dc, self._features_rest๋ฅผ ์ ์ฅํ  ๋๋ ๋ค์ # (n_points, RGB, sh ๊ณ์)๋ก shape์ ๋ง๋ค์ด ์ ์ฅํฉ๋๋ค.
# 3dgs/scene/gaussian_model.py
class GaussianModel:
    # ... (other members omitted in this excerpt)
    ...

    def save_ply(self, path):
        """Write the current Gaussians to the ``.ply`` file at ``path``.

        SH coefficients are transposed back from the training layout
        (n_points, n_coeffs, RGB) to (n_points, RGB, n_coeffs) and flattened,
        so each row stores all R coefficients, then G, then B
        (f_dc: (n, 3), f_rest: (n, 3 * (n_coeffs - 1))).
        """
        # NOTE(review): for a bare filename os.path.dirname(path) is '' — confirm mkdir_p tolerates that.
        mkdir_p(os.path.dirname(path))
        xyz = self._xyz.detach().cpu().numpy()
        normals = np.zeros_like(xyz)  # no normals are estimated; written as zeros
        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        opacities = self._opacity.detach().cpu().numpy()
        scale = self._scaling.detach().cpu().numpy()
        rotation = self._rotation.detach().cpu().numpy()
        # One float32 ('f4') field per attribute name; presumably construct_list_of_attributes()
        # lists names in the same order as the columns concatenated below — confirm.
        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, 'vertex')
        PlyData([el]).write(path)
initial self._features_dc # (n_points, sh 0๋ฒ์งธ ๊ณ์, RGB) = (100000, 1, 3)
self._features_dc # (optimized n_points, sh 0๋ฒ์งธ ๊ณ์, RGB) = (317737, 1 3)
self._features_dc.detach().transpose(1, 2) # (optimized n_points, RGB, sh 0๋ฒ์งธ ๊ณ์) = (317737, 3, 1)
self._features_dc.detach().transpose(1, 2).flatten(start_dim=1) # (optimized n_points, RGB * sh 0๋ฒ์งธ ๊ณ์) = (317737, 3)
f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()์ด๋ฏ๋ก
initial self._features_rest # (n_points, sh 1~15๋ฒ์งธ ๊ณ์, RGB) = (100000, 15, 3)
self._features_rest # (optimized n_points, sh 1~15๋ฒ์งธ ๊ณ์, RGB) = (317737, 15 3)
self._features_rest.detach().transpose(1, 2) # (optimized n_points, RGB, sh 1~15๋ฒ์งธ ๊ณ์) = (317737, 3, 15)
self._features_rest.detach().transpose(1, 2).flatten(start_dim=1) # (optimized n_points, RGB * sh 1~15๋ฒ์งธ ๊ณ์) = (317737, 45)
f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()์ด๋ฏ๋ก
DC 성분이란?
- DC 성분 (Direct Current 성분): SH 표현의 가장 기본적인 주파수 성분으로, 상수 함수입니다. 이는 구면에서 모든 방향에 대해 일정한 값을 가지며, 0차 구면 조화 함수로 볼 수 있습니다.
 - 역할:
    
- 기본 조명 표현: 방향에 관계없이 일정한 조명을 나타냅니다.
 - 베이스라인 설정: 고차 구면 조화 함수와 함께 사용되어 방향에 따른 조명 변화를 더 잘 표현할 수 있도록 합니다.
 
 
SH ๊ณ์์ ๊ตฌ์กฐ
๊ฐ SH Level๋ง๋ค DC ์ฑ๋ถ๊ณผ rest ์ฑ๋ถ์ด ์กด์ฌํ๋ฉฐ, RGB ์ฑ๋๋ณ๋ก ๊ฐ๊ฐ์ ๊ณ์๊ฐ ๋ ๋ฆฝ์ ์ผ๋ก ์กด์ฌํฉ๋๋ค.
Level 0
- DC ์ฑ๋ถ: 1๊ฐ (RGB ๊ฐ๊ฐ 1๊ฐ์ฉ ์ด 3๊ฐ)
 - Rest ์ฑ๋ถ: ์์
 - ์ด ๊ณ์: 1๊ฐ์ DC ์ฑ๋ถ x 3 (RGB ์ฑ๋) = 3๊ฐ
 
Level 1
- DC ์ฑ๋ถ: 1๊ฐ (RGB ๊ฐ๊ฐ 1๊ฐ์ฉ ์ด 3๊ฐ)
 - Rest ์ฑ๋ถ: 3๊ฐ (๊ฐ RGB ์ฑ๋๋ง๋ค 3๊ฐ์ ์ถ๊ฐ ๊ณ์)
 - ์ด ๊ณ์: 1๊ฐ์ DC ์ฑ๋ถ + 3๊ฐ์ rest ์ฑ๋ถ = 4๊ฐ
    
- RGB ๊ฐ๊ฐ์ ๋ํด 4๊ฐ์ ๊ณ์ = 4 x 3 = 12๊ฐ
 
 
Level 2
- DC ์ฑ๋ถ: 1๊ฐ (RGB ๊ฐ๊ฐ 1๊ฐ์ฉ ์ด 3๊ฐ)
 - Rest ์ฑ๋ถ: 8๊ฐ (๊ฐ RGB ์ฑ๋๋ง๋ค 8๊ฐ์ ์ถ๊ฐ ๊ณ์)
 - ์ด ๊ณ์: 1๊ฐ์ DC ์ฑ๋ถ + 8๊ฐ์ rest ์ฑ๋ถ = 9๊ฐ
    
- RGB ๊ฐ๊ฐ์ ๋ํด 9๊ฐ์ ๊ณ์ = 9 x 3 = 27๊ฐ
 
 
Level 3
- DC ์ฑ๋ถ: 1๊ฐ (RGB ๊ฐ๊ฐ 1๊ฐ์ฉ ์ด 3๊ฐ)
 - Rest ์ฑ๋ถ: 15๊ฐ (๊ฐ RGB ์ฑ๋๋ง๋ค 15๊ฐ์ ์ถ๊ฐ ๊ณ์)
 - ์ด ๊ณ์: 1๊ฐ์ DC ์ฑ๋ถ + 15๊ฐ์ rest ์ฑ๋ถ = 16๊ฐ
    
- RGB ๊ฐ๊ฐ์ ๋ํด 16๊ฐ์ ๊ณ์ = 16 x 3 = 48๊ฐ
 
 
์์ฝ
- DC ์ฑ๋ถ: ๊ธฐ๋ณธ์ ์ธ SH ์ฑ๋ถ์ผ๋ก, ๋ชจ๋ ๋ฐฉํฅ์์ ์ผ์ ํ ๊ฐ์ ๊ฐ์ง.
 - Rest ์ฑ๋ถ: DC ์ฑ๋ถ์ ์ ์ธํ ๋๋จธ์ง ๊ณ ์ฐจ SH ์ฑ๋ถ๋ค.
 - SH Levels์ ๊ณ์: ๊ฐ Level์ ๋ฐ๋ผ DC ์ฑ๋ถ์ ํฌํจํ ์ด ๊ณ์์ ์๊ฐ ๊ฒฐ์ ๋จ.
 - RGB ์ฑ๋๋ณ ๋ ๋ฆฝ์  ๊ณ์: ๊ฐ RGB ์ฑ๋์ ๋ํด SH ๊ณ์๋ค์ด ๊ฐ๋ณ์ ์ผ๋ก ์กด์ฌํ์ฌ ์ปฌ๋ฌ ์ ๋ณด๋ฅผ ์ ํํ ํํํจ.
 
open3d mesh๋ฅผ ๋ถ๋ฌ์์ initializeํ๋ ๋ฒ์ ์์๋ด ์๋ค.
- surface_mesh_to_bind๋ o3d mesh์ด๊ณ
 - n_points = surface_mesh_to_bind์ triangle ์ * triangle๋น gaussian ์ ์ ๋๋ค.
 - ์ฆ, n_points๋ triangle๋ค ์์ ์ ์ํ gaussian๋ค์ ์ด ๊ฐ์๋ฅผ ์๋ฏธํฉ๋๋ค.
 
    @property
    def n_points(self):
        """Number of Gaussians held by the model.

        When the model is bound to a surface mesh the stored ``_n_points``
        count is returned; otherwise it is the length of ``_points``.
        """
        if self.binded_to_surface_mesh:
            return self._n_points
        return len(self._points)
n_points์์ spherical harmonics (sh)์ dc์ rest๋ฅผ ์ ์ํ๋ ์ฝ๋๋ฅผ ๋ด ์๋ค.
colors # shape (n_vertices, n_coords)์ ๋๋ค.- ์ด๋ 
n_coords๋vertices์ color ์ ๋ณด์ ๋๋ค!! ์ ํํ ์ฃผ์์ ํํํ๋ฉดn_coords๊ฐ ์๋๋ผrgb๋ก ํ์ํด์ผ ํฉ๋๋ค. - 
    
ํ์ง๋ง
colors์ rgb 3์ฐจ์๊ณผ,vertices์ xyz 3์ฐจ์ ์ ๋ณด๊ฐ ์ฐจ์์ด ๊ฐ์์ ์ฃผ์๋n_coords๋ก ํต์ผํ ๊ฒ์ผ๋ก ๋ณด์ ๋๋ค.colors # shape (n_vertices, n_coords) <-- vertices์ ๋ํ rgb 3์ฐจ์ color ๊ฐ = n_coords๋ก ํํ vertices # shape (n_vertices, n_coords) <-- vertices์ ๋ํ xyz 3์ฐจ์ ์ขํ๊ฐ = n_coords๋ก ํํ colors์์sh_coordinates_dc๋ฅผ ๋ง๋๋ ๊ณผ์ ์ ๋ด ์๋ค.
def RGB2SH(rgb):
    """Convert an RGB color to its DC (degree-0) SH coefficient; the inverse of SH2RGB."""
    centered = rgb - 0.5
    return centered / C0
def SH2RGB(sh):
    """Convert a DC (degree-0) SH coefficient back to an RGB color; the inverse of RGB2SH."""
    scaled = sh * C0
    return scaled + 0.5
sh_coordinates_dc = RGB2SH(colors).unsqueeze(dim=1) # shape (n_vertices, 1, n_coords)- ์ฃผ์์ ์ ๋๋ก ์ฐ๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
 sh_coordinates_dc = RGB2SH(colors).unsqueeze(dim=1) # shape (n_vertices, 1, rgb)- 
    
์ฆ,
๋ง์ง๋ง rgb 3๊ฐ์ ์ฐจ์์ ๋ํด,sh_coordiantes_dc๋ ๊ฐ๊ฐ sh๋ฅผ 1๊ฐ์ฉ ๊ฐ์ง๊ฒ ๋ฉ๋๋ค. self._sh_coordinates_rest๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.self._sh_coordinates_rest = torch.zeros(n_points, sh_levels**2 - 1, 3) # shape (n_points, sh_levels**2 - 1, 3)- ์ฆ, 
๋ง์ง๋ง rgb 3๊ฐ์ ์ฐจ์์ ๋ํด,self._sh_coordinates_rest๋ ๊ฐ๊ฐsh_levels**2-1 ๊ฐ๋งํผ์ ๊ฐ์ง๊ฒ ๋ฉ๋๋ค. - ์ฆ, 
rgb ๊ฐ ์ฑ๋๋ณ๋กsh ๊ณ์๊ฐ ํ ๋น๋ฉ๋๋ค. 
        # Initialize color features
        # Only the DC term is initialised from `colors` (via RGB2SH); all other
        # SH coefficients start at zero.
        self.sh_levels = sh_levels
        # presumably colors is (n, 3) RGB -> (n, 1, 3): one DC coefficient per channel — confirm.
        sh_coordinates_dc = RGB2SH(colors).unsqueeze(dim=1)
        # `True and (not freeze_gaussians)` evaluates to `not freeze_gaussians`:
        # frozen Gaussians get parameters that do not require gradients.
        # NOTE(review): the trailing .to(device) is a no-op because the tensor was
        # already moved; if it ever changed device it would presumably return a plain
        # Tensor instead of the Parameter.
        self._sh_coordinates_dc = nn.Parameter(
            sh_coordinates_dc.to(self.nerfmodel.device),
            requires_grad=True and (not freeze_gaussians)
        ).to(self.nerfmodel.device)
        
        # (n_points, sh_levels**2 - 1, 3): every non-DC coefficient for each RGB channel, zero-initialised.
        self._sh_coordinates_rest = nn.Parameter(
            torch.zeros(n_points, sh_levels**2 - 1, 3).to(self.nerfmodel.device),
            requires_grad=True and (not freeze_gaussians)
        ).to(self.nerfmodel.device)
# gaussian-splatting/scene/gaussian_model.py
class GaussianModel:
    # ... (other members omitted in this excerpt)
    ...

    @property
    def get_features(self):
        """All SH coefficients of every Gaussian, DC term first.

        ``_features_dc`` (one DC coefficient per RGB channel) and
        ``_features_rest`` (the higher-order coefficients) are concatenated
        along dim=1, the SH-coefficient axis. With the 3DGS layout
        (n_points, n_coeffs, RGB) this is e.g. (n, 1, 3) + (n, 15, 3)
        -> (n, 16, 3) for SH degree 3.
        """
        features_dc = self._features_dc
        features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)
3dgs์์ _features_dc, _features_rest๋ spherical harmonics์ ๊ณ์์ ํด๋ฑํฉ๋๋ค.
- ๋ฐ๋ผ์ ์ด๋ฅผ SuGaR์์๋ ๋ณ์์ด๋ฆ์ ๊ทธ๋ฅ 
_sh_coordinates_dc[...],_sh_coordinates_rest[...]๋ก ๋ถ๋ฌ์์ ์ฌ์ฉํฉ๋๋ค. 
# SuGaR/sugar_extractors/coarse_mesh.py
def extract_mesh_from_coarse_sugar(args):
    # Excerpt: the code before these lines, and the nested blocks they sit in, are elided.
...
            # Copy the 3DGS SH coefficients (DC and rest) into the existing SuGaR
            # tensors in place (`[...] =`); detach() drops the source autograd history.
            # NOTE(review): in-place writes into leaf tensors that require grad need
            # torch.no_grad(); presumably the elided enclosing block provides it — confirm.
            sugar._sh_coordinates_dc[...] = nerfmodel.gaussians._features_dc.detach()
            sugar._sh_coordinates_rest[...] = nerfmodel.gaussians._features_rest.detach()
    
3dgs์์ _features_dc๋ RGB์ ํด๋นํ๋ sh(spherical harmonics)๊ณ์์
๋๋ค.
- ๋ฐ๋ผ์ ์ด๋ฅผ SuGaR์์๋ 
SH2RGBํจ์๋ก sh๋ฅผ rgb๋ก ๋ณํํ์ฌcolors๋ณ์์ ํ ๋นํฉ๋๋ค. 
# SuGaR/sugar_extractors/coarse_mesh.py
def extract_mesh_from_coarse_sugar(args):
    # Excerpt: the code before these lines, and the blocks they are nested in, are elided.
...
        CONSOLE.print(f"\nLoading the coarse SuGaR model from path {sugar_checkpoint_path}...")
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(sugar_checkpoint_path, map_location=nerfmodel.device)
        # Base RGB colors from the DC coefficient only: [:, 0, :] picks SH coefficient 0 of each channel.
        colors = SH2RGB(checkpoint['state_dict']['_sh_coordinates_dc'][:, 0, :])
        # Build a model with the same points/colors, then overwrite all parameters
        # from the checkpoint with load_state_dict below.
        sugar = SuGaR(
            nerfmodel=nerfmodel,
            points=checkpoint['state_dict']['_points'],
            colors=colors,
            initialize=True,
            sh_levels=nerfmodel.gaussians.active_sh_degree+1,
            keep_track_of_knn=True,
            knn_to_track=16,
            beta_mode='average',  # 'learnable', 'average', 'weighted_average'
            primitive_types='diamond',  # 'diamond', 'square'
            surface_mesh_to_bind=None,  # Open3D mesh
            )
        sugar.load_state_dict(checkpoint['state_dict'])
    sugar.eval()
- ๊ทธ๋ฆฌ๊ณ  ์ด 
colors๋ ๋ถ๋ฌ์์RGB2SHํจ์๋ก ๋ณํํ๊ณunsqueezeํ์ฌ ์ฐจ์์ ๋ง์ถฐ ํ์ต์ initialize๋ก ์ค ์ ์์ต๋๋ค. 
SH ๊ณ์๋ฅผ RGB๋ก ๋ณํํ๊ธฐ ์ํด ํ์ํ ์ ๋ณด
SH ๊ณ์๋ฅผ RGB๋ก ๋ณํํ๊ธฐ ์ํด์๋ ๋ค์์ ์ ๋ณด๊ฐ ํ์ํฉ๋๋ค: sh์ deg, sh์ ๊ณ์, camera center์์ point๊น์ง์ direction.
- SuGaR์์๋ 
eval_sh(deg, sh, dirs)๋ฅผ ์ฌ์ฉํ์ฌ ํน์  ์นด๋ฉ๋ผ ๋ฐฉํฅ์์ ํฌ์ธํธ๊น์ง์ ๋ ๋๋ง ๋ฐฉํฅ์ ๋ฐ๋ผ SH ๊ณ์๋ฅผ ํ๋์ RGB ์ปฌ๋ฌ๋ก ๋ณํํฉ๋๋ค. - RGB 3๊ฐ์ ์ฑ๋๋ง๋ค SH ๊ณ์๊ฐ ๋ ๋ฆฝ์ ์ผ๋ก ์กด์ฌํฉ๋๋ค.
 - DC๋ ์กฐ๋ช ์ ์ ์ฒด์ ์ธ ๋ฐ๊ธฐ๋ฅผ ๋ํ๋ด๋ ์์์ ๋๋ค.
 
SH์ DC (0๋ฒ์งธ Band)
sh[..., 0]์ 0๋ฒ์งธ deg์ ํด๋นํ๋ SH์ DC ๊ณ์๋ก, ๋ชจ๋ RGB ์ฑ๋์ ๋ํด ๋์ผํฉ๋๋ค.
SH์ Rest (1๋ฒ์งธ, 2๋ฒ์งธ, 3๋ฒ์งธ, โฆ)
sh[..., 1]๋ถํฐsh[..., 24]๊น์ง์ ๊ณ์๋ ๊ฐ๋๊ฐ ์๋ ์กฐ๋ช ๊ตฌ์ฑ ์์์ ๋๋ค.- SH ๊ณ์๋ ๊ฐ RGB ์ฑ๋์ ๋ํด ๊ฐ๋ณ์ ์ผ๋ก ๊ณ์ฐ๋ฉ๋๋ค. ์ฆ, R, G, B ๊ฐ ์ฑ๋์ ์๋ก ๋ค๋ฅธ SH ๊ณ์๋ฅผ ๊ฐ์ง ์ ์์ต๋๋ค.
 
예시로 설명
예를 들어, RGB 채널 각각에 대해 SH 계수가 다음과 같이 있을 수 있습니다:
R 채널
sh[..., 0] = 0.5
sh[..., 1] = 0.1
sh[..., 2] = 0.3
...
G 채널
sh[..., 0] = 0.4
sh[..., 1] = 0.2
sh[..., 2] = 0.6
...
B 채널
sh[..., 0] = 0.7
sh[..., 1] = 0.3
sh[..., 2] = 0.5
...
각 계수는 각 RGB 채널에 대해 개별적으로 계산됩니다. 계수 인덱스 k는 deg l = ⌊√k⌋에 속하며, deg l에는 2l + 1개의 계수가 있습니다:
- sh[..., 0]은 0차(deg 0) DC 계수입니다.
 - sh[..., 1] ~ sh[..., 3]은 1차(deg 1) SH 계수로, 각 RGB 채널에 대해 독립적으로 계산된 값을 가집니다.
 - sh[..., 4] ~ sh[..., 8]은 2차(deg 2) SH 계수로, 각 RGB 채널에 대해 독립적으로 계산된 값을 가집니다.
 - sh[..., 9] ~ sh[..., 15]는 3차(deg 3) SH 계수로, 각 RGB 채널에 대해 독립적으로 계산된 값을 가집니다.
 - sh[..., 16] ~ sh[..., 24]는 4차(deg 4) SH 계수로, 각 RGB 채널에 대해 독립적으로 계산된 값을 가집니다.
 
요약
- SH 계수는 각 RGB 채널에 대해 개별적으로 존재하고, 계산됩니다.
 - 이는 채널들이 동일한 값을 가진다는 의미가 아니며, 각 채널별로 다른 값을 가질 수 있습니다.
 - 각 채널의 SH 계수는 개별적으로 처리되며, eval_sh 함수는 각 채널에 대해 독립적으로 SH 계수를 사용하여 RGB 값을 계산합니다.
 
# SuGaR/sugar_scene/sugar_model.py
    def get_points_rgb(
        self,
        positions:torch.Tensor=None,
        camera_centers:torch.Tensor=None,
        directions:torch.Tensor=None,
        sh_levels:int=None,
        sh_coordinates:torch.Tensor=None,
        ):
        """Returns the RGB color of the points for the given camera pose.

        The SH coefficients are evaluated along each point's viewing direction
        with ``eval_sh``, shifted by +0.5 and clamped from below at 0.

        Args:
            positions (torch.Tensor, optional): Shape (n_pts, 3). Defaults to ``self.points``.
            camera_centers (torch.Tensor, optional): Shape (n_pts, 3) or (1, 3). If given,
                the viewing directions are ``normalize(positions - camera_centers)``.
            directions (torch.Tensor, optional): Viewing directions, used only when
                ``camera_centers`` is None. They are passed to ``eval_sh`` as-is,
                so they are expected to already be unit vectors.
            sh_levels (int, optional): Number of SH levels (degree + 1) to evaluate;
                only the first ``sh_levels**2`` coefficients are used. Defaults to
                every level stored in ``sh_coordinates``.
            sh_coordinates (torch.Tensor, optional): SH coefficients of shape
                (n_pts, n_coeffs, 3). Defaults to ``self.sh_coordinates``.
        Raises:
            ValueError: If neither ``camera_centers`` nor ``directions`` is given.
        Returns:
            torch.Tensor: Shape (n_pts, 3), non-negative RGB colors.
        """
        if positions is None:
            positions = self.points

        if camera_centers is not None:
            render_directions = torch.nn.functional.normalize(positions - camera_centers, dim=-1)
        elif directions is not None:
            render_directions = directions
        else:
            raise ValueError("Either camera_centers or directions must be provided.")

        if sh_coordinates is None:
            sh_coordinates = self.sh_coordinates

        if sh_levels is None:
            # Default to all stored bands: the coefficient axis holds sh_levels**2
            # entries. (Leaving sh_levels as None made `sh_levels**2` below raise.)
            sh_levels = int(sh_coordinates.shape[-2] ** 0.5)
        # Keep only the coefficients of the first `sh_levels` bands.
        sh_coordinates = sh_coordinates[:, :sh_levels**2]

        # (n_pts, n_coeffs, 3) -> (n_pts, 3, n_coeffs): eval_sh expects coefficients last.
        shs_view = sh_coordinates.transpose(-1, -2).view(-1, 3, sh_levels**2)
        sh2rgb = eval_sh(sh_levels-1, shs_view, render_directions)
        # +0.5 is the same offset SH2RGB applies to the DC term.
        colors = torch.clamp_min(sh2rgb + 0.5, 0.0).view(-1, 3)

        return colors
# SuGaR/sugar_utils/spherical_harmonics.py
def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp.
    ... Can be 0 or more batch dimensions.

    Coefficient k of ``sh`` belongs to band l = floor(sqrt(k)): index 0 is the
    DC term, 1-3 band 1, 4-8 band 2, 9-15 band 3, 16-24 band 4. Only the SH
    sum is returned; no offset (such as +0.5) is added here.

    Args:
        deg: int SH deg. Currently, 0-4 supported (see the assert below).
        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]; only the first
            (deg + 1) ** 2 entries of the last axis are read.
        dirs: jnp.ndarray unit directions [..., 3]; not normalized here.
    Returns:
        [..., C]
    """
    assert deg <= 4 and deg >= 0
    coeff = (deg + 1) ** 2
    assert sh.shape[-1] >= coeff
    # Band 0 (DC): direction-independent constant term.
    result = C0 * sh[..., 0]
    if deg > 0:
        # Slices (not integer indices) keep a trailing axis of size 1, so x, y, z
        # broadcast against sh[..., k] (shape [..., C]) across all channels.
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        # Band 1: coefficients 1-3 (signs are folded into the expression, C1 is shared).
        result = (result -
                C1 * y * sh[..., 1] +
                C1 * z * sh[..., 2] -
                C1 * x * sh[..., 3])
        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            # Band 2: coefficients 4-8.
            result = (result +
                    C2[0] * xy * sh[..., 4] +
                    C2[1] * yz * sh[..., 5] +
                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
                    C2[3] * xz * sh[..., 7] +
                    C2[4] * (xx - yy) * sh[..., 8])
            if deg > 2:
                # Band 3: coefficients 9-15.
                result = (result +
                C3[0] * y * (3 * xx - yy) * sh[..., 9] +
                C3[1] * xy * z * sh[..., 10] +
                C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
                C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
                C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
                C3[5] * z * (xx - yy) * sh[..., 14] +
                C3[6] * x * (xx - 3 * yy) * sh[..., 15])
                if deg > 3:
                    # Band 4: coefficients 16-24 (these polynomials assume |dirs| == 1).
                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
                            C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result
# Real spherical-harmonic basis constants used by eval_sh (bands 0-4).
# C0 = 1 / (2 * sqrt(pi)): the constant band-0 (DC) basis value; RGB2SH/SH2RGB also use it.
C0 = 0.28209479177387814
# C1 = sqrt(3 / (4 * pi)): shared magnitude of the three band-1 basis functions.
C1 = 0.4886025119029199
# Band-2 constants, in the order of sh[..., 4] .. sh[..., 8] in eval_sh (signs included).
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396
]
# Band-3 constants, in the order of sh[..., 9] .. sh[..., 15] in eval_sh.
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435
]
# Band-4 constants, in the order of sh[..., 16] .. sh[..., 24] in eval_sh.
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]
SH2RGB์ RGB2SH๋ sh์ 0๋ฒ์งธ band์ธ dc ์ฑ๋ถ์์๋ง ์ฌ์ฉ๋๋ ํจ์์ ๋๋ค.
def RGB2SH(rgb):
    """Convert an RGB color to its DC (degree-0) SH coefficient; the inverse of SH2RGB."""
    centered = rgb - 0.5
    return centered / C0
def SH2RGB(sh):
    """Convert a DC (degree-0) SH coefficient back to an RGB color; the inverse of RGB2SH."""
    scaled = sh * C0
    return scaled + 0.5
RGB2SH 함수
sh_coordinates_dc = RGB2SH(colors).unsqueeze(dim=1)  # presumably (n, 3) RGB -> (n, 1, 3): one DC SH coefficient per channel
RGB2SH ํจ์๋RGB ๊ฐ์ ๋ฐ์์ SH ๊ณ์๋ก ๋ณํํฉ๋๋ค.๋ณํ๋ SH ๊ณ์๋ DC ์ฑ๋ถ์ ํด๋นํฉ๋๋ค.RGB2SH(colors):colors๋ RGB ๊ฐ์ ๋ํ๋ด๋ฉฐ,DC ์ฑ๋ถ์ ํด๋นํฉ๋๋ค.- ๋ฐ๋ผ์, ์ด ๋ณํ์ 
RGB ๊ฐ์ DC ์ฑ๋ถ์ SH ๊ณ์๋ก ๋ณํํฉ๋๋ค. .unsqueeze(dim=1): ์ฐจ์์ ์ถ๊ฐํ์ฌ SH ๊ณ์์ DC ์ฑ๋ถ์ 3D ํ ์๋ก ๋ง๋ญ๋๋ค.
SH2RGB ํจ์
    # Rebuild a SuGaR model bound to `o3d_mesh`; the base colors come from the DC
    # coefficient only ([:, 0, :]). initialize=False, and every parameter is then
    # overwritten from the checkpoint by load_state_dict below.
    refined_sugar = SuGaR(
        nerfmodel=nerfmodel,
        points=checkpoint['state_dict']['_points'],
        colors=SH2RGB(checkpoint['state_dict']['_sh_coordinates_dc'][:, 0, :]), 
        initialize=False,
        sh_levels=nerfmodel.gaussians.active_sh_degree+1,
        keep_track_of_knn=False,
        knn_to_track=0,
        beta_mode='average',
        surface_mesh_to_bind=o3d_mesh,
        n_gaussians_per_surface_triangle=n_gaussians_per_surface_triangle,
        )
    refined_sugar.load_state_dict(checkpoint['state_dict'])
# UV texture built from the DC SH coefficient only ([..., 0, :]), converted to RGB.
# [None] prepends one dimension; presumably TexturesUV expects batched inputs (batch size 1) — confirm.
textures_uv = TexturesUV(
            maps=SH2RGB(self.texture_features[..., 0, :][None]), #texture_img[None]), 
            verts_uvs=self.verts_uv[None],
            faces_uvs=self.faces_uv[None],
            sampling_mode='nearest',
            )
SH2RGB ํจ์๋SH ๊ณ์๋ฅผ ๋ฐ์์ RGB ๊ฐ์ผ๋ก ๋ณํํฉ๋๋ค.- ์ด ํจ์๋ 
DC ์ฑ๋ถ์ ํด๋นํ๋ SH ๊ณ์๋ฅผ ๋ณํํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. self.texture_features[..., 0, :]์ SH ๊ณ์์ ์ฒซ ๋ฒ์งธ ์ฑ๋ถ(DC ์ฑ๋ถ)์ ์ ํํฉ๋๋ค.[None]์ ์ฌ์ฉํ์ฌ ์ฐจ์์ ์ถ๊ฐํฉ๋๋ค.SH2RGB ํจ์๋ ์ดDC ์ฑ๋ถ์ RGB ๊ฐ์ผ๋ก ๋ณํํฉ๋๋ค.
๊ฒฐ๋ก
RGB2SH์ SH2RGB๋ ์ฃผ์ด์ง ์ฝ๋์์ SH ๊ณ์์ DC ์ฑ๋ถ์ ํด๋นํ๋ ๊ฐ์ ๋ณํํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. ๋ฐ๋ผ์, ์ด ๋ ํจ์๋ DC ์ฑ๋ถ์๋ง ํด๋นํ๋ค๊ณ  ๋ณผ ์ ์์ต๋๋ค.
์ต์ข ์์ฝ
RGB2SH ํจ์๋DC ์ฑ๋ถ์ ํด๋นํ๋ RGB ๊ฐ์ SH ๊ณ์๋ก ๋ณํํฉ๋๋ค.SH2RGB ํจ์๋DC ์ฑ๋ถ์ ํด๋นํ๋ SH ๊ณ์๋ฅผ RGB ๊ฐ์ผ๋ก ๋ณํํฉ๋๋ค.- ์ฃผ์ด์ง ์ฝ๋์์๋ ์ด ๋ ํจ์๊ฐ SH ๊ณ์์ DC ์ฑ๋ถ์ ์ฃผ๋ก ์ฌ์ฉ๋๊ณ ์์ต๋๋ค.
 
[3D CV 연구] 3DGS input & output .ply properties & Meshlab Vert & Spherical Harmonics (SH) & Mesh
      
Leave a comment