[3D CV ์ฐ๊ตฌ] 3DGS SH2RGB
Spherical Harmonics (SH)์ DC ์ฑ๋ถ
- Spherical Harmonics(SH)์ ๊ตฌ๋ฉด ์ขํ๊ณ์์ ์ ์๋๋ ํจ์์ ์งํฉ์ผ๋ก, 3D ๊ทธ๋ํฝ์ค์์ ์กฐ๋ช , ๋ฐ์ฌ, ์ํฅ ๋ฑ์ ํํํ๋๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
- SH ๊ณ์๋ SH Level์ ๋ฐ๋ผ DC ์ฑ๋ถ๊ณผ ๋๋จธ์ง ์ฑ๋ถ์ผ๋ก ๋๋ฉ๋๋ค.
- ์๋ ์ฝ๋์์ ์ค์ฌ์ฉ ์์๋ฅผ ๋จผ์ ์ดํดํด๋ด ์๋ค.
features = torch.cat([features_dc, features_rest], dim=1)
# (n_points, sh coefficients, RGB)
# 3dgs/scene/gaussian_model.py
class GaussianModel:
    ...
    @property
    def get_features(self):
        features_dc = self._features_dc
        features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)
features_dc # (n_points, 1, RGB)
features_rest # (n_points, 15, RGB)
features_dc and features_rest are concatenated along the SH-coefficient dimension.
For each of the three RGB channels there is a separate slot for the n-th SH coefficient.
- features_dc[:, 0, 0]: the 0th SH coefficient of the R channel for every point, stored in the PLY under the property name "f_dc_0".
- features_dc[:, 0, 1]: the 0th SH coefficient of the G channel, stored as "f_dc_1".
- features_dc[:, 0, 2]: the 0th SH coefficient of the B channel, stored as "f_dc_2".
- When features_dc and features_extra are loaded from a PLY file, their shapes are (n_points, RGB, sh coefficients).
- For training, however, both features_dc and features_extra are transpose(1, 2)'d and stored as self._features_dc and self._features_rest, whose shapes become (n_points, sh coefficients, RGB).
# 3dgs/scene/gaussian_model.py
class GaussianModel:
    ...
    def load_ply(self, path):
        ...
        features_dc = np.zeros((xyz.shape[0], 3, 1))
        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])

        extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
        extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split('_')[-1]))
        assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3
        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
        for idx, attr_name in enumerate(extra_f_names):
            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
        # Reshape (P, F*SH_coeffs) to (P, F, SH_coeffs except DC)
        features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
        ...
        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
        ...
self._features_dc and self._features_rest, now reshaped to (n_points, sh coefficients, RGB), are concatenated along dim=1, the SH-coefficient dimension.
# 3dgs/scene/gaussian_model.py
class GaussianModel:
    ...
    @property
    def get_features(self):
        features_dc = self._features_dc
        features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)
f_dc (self._features_dc) and f_rest (self._features_rest) are used during training in the (n_points, sh coefficients, RGB) layout; the shape sketch below makes this concrete.
Let's look at where self._features_dc and self._features_rest are defined.
- When the point cloud is created by random initialization with 100,000 points, features.shape is (n_points, RGB, sh coefficients) = (100000, 3, 16).
- self._features_dc is features[:, :, 0:1].transpose(1, 2): (n_points, RGB, sh coefficient 0) becomes (n_points, sh coefficient 0, RGB) = (100000, 1, 3).
- self._features_rest is features[:, :, 1:].transpose(1, 2): (n_points, RGB, sh coefficients 1-15) becomes (n_points, sh coefficients 1-15, RGB) = (100000, 15, 3).
# 3dgs/scene/gaussian_model.py
class GaussianModel:
    ...
    def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
        self.spatial_lr_scale = spatial_lr_scale
        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
        features[:, :3, 0] = fused_color
        features[:, 3:, 1:] = 0.0

        print("Number of points at initialisation : ", fused_point_cloud.shape[0])

        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
        rots[:, 0] = 1

        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))

        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
        self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True))
        ...
์ต์ข
์ ์ผ๋ก self._features_dc
๊ณผ self._features_rest
๋ sh ๊ณ์
์ฑ๋์์ concatํ์ฌ ํ์ต์ ์ฌ์ฉ๋ฉ๋๋ค.
class GaussianModel:
    ...
    @property
    def get_features(self):
        features_dc = self._features_dc
        features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)
- features_dc.shape  # (n_points, sh coefficient 0, RGB) = (100000, 1, 3)
- features_rest.shape  # (n_points, sh coefficients 1-15, RGB) = (100000, 15, 3)
- torch.cat((features_dc, features_rest), dim=1)  # (n_points, sh coefficients 0-15, RGB) = (100000, 16, 3)
ํ์ต๋ self._features_dc
, self._features_rest
๋ฅผ ์ ์ฅํ ๋๋ ๋ค์ # (n_points, RGB, sh ๊ณ์)๋ก shape์ ๋ง๋ค์ด ์ ์ฅํฉ๋๋ค.
# 3dgs/scene/gaussian_model.py
class GaussianModel:
    ...
    def save_ply(self, path):
        mkdir_p(os.path.dirname(path))

        xyz = self._xyz.detach().cpu().numpy()
        normals = np.zeros_like(xyz)
        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        opacities = self._opacity.detach().cpu().numpy()
        scale = self._scaling.detach().cpu().numpy()
        rotation = self._rotation.detach().cpu().numpy()

        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]

        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, 'vertex')
        PlyData([el]).write(path)
Tracing f_dc:
initial self._features_dc  # (n_points, sh coefficient 0, RGB) = (100000, 1, 3)
self._features_dc  # (optimized n_points, sh coefficient 0, RGB) = (317737, 1, 3)
self._features_dc.detach().transpose(1, 2)  # (optimized n_points, RGB, sh coefficient 0) = (317737, 3, 1)
self._features_dc.detach().transpose(1, 2).flatten(start_dim=1)  # (optimized n_points, RGB * sh coefficient 0) = (317737, 3)
hence f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy().

Tracing f_rest:
initial self._features_rest  # (n_points, sh coefficients 1-15, RGB) = (100000, 15, 3)
self._features_rest  # (optimized n_points, sh coefficients 1-15, RGB) = (317737, 15, 3)
self._features_rest.detach().transpose(1, 2)  # (optimized n_points, RGB, sh coefficients 1-15) = (317737, 3, 15)
self._features_rest.detach().transpose(1, 2).flatten(start_dim=1)  # (optimized n_points, RGB * sh coefficients 1-15) = (317737, 45)
hence f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy().
DC ์ฑ๋ถ์ด๋?
- DC ์ฑ๋ถ (Direct Current ์ฑ๋ถ): SH ํํ์ ๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ ์ฃผํ์ ์ฑ๋ถ์ผ๋ก, ์์ ํจ์์ ๋๋ค. ์ด๋ ๊ตฌ๋ฉด์์ ๋ชจ๋ ๋ฐฉํฅ์ ๋ํด ์ผ์ ํ ๊ฐ์ ๊ฐ์ง๋ฉฐ, 0์ฐจ ๊ตฌ๋ฉด ์กฐํ ํจ์๋ก ๋ณผ ์ ์์ต๋๋ค.
- ์ญํ :
- ๊ธฐ๋ณธ ์กฐ๋ช ํํ: ๋ฐฉํฅ์ ๊ด๊ณ์์ด ์ผ์ ํ ์กฐ๋ช ์ ๋ํ๋ ๋๋ค.
- ๋ฒ ์ด์ค๋ผ์ธ ์ค์ : ๊ณ ์ฐจ ๊ตฌ๋ฉด ์กฐํ ํจ์์ ํจ๊ป ์ฌ์ฉ๋์ด ๋ฐฉํฅ์ ๋ฐ๋ฅธ ์กฐ๋ช ๋ณํ๋ฅผ ๋ ์ ํํํ ์ ์๋๋ก ํฉ๋๋ค.
SH ๊ณ์์ ๊ตฌ์กฐ
๊ฐ SH Level๋ง๋ค DC ์ฑ๋ถ๊ณผ rest ์ฑ๋ถ์ด ์กด์ฌํ๋ฉฐ, RGB ์ฑ๋๋ณ๋ก ๊ฐ๊ฐ์ ๊ณ์๊ฐ ๋ ๋ฆฝ์ ์ผ๋ก ์กด์ฌํฉ๋๋ค.
Level 0
- DC ์ฑ๋ถ: 1๊ฐ (RGB ๊ฐ๊ฐ 1๊ฐ์ฉ ์ด 3๊ฐ)
- Rest ์ฑ๋ถ: ์์
- ์ด ๊ณ์: 1๊ฐ์ DC ์ฑ๋ถ x 3 (RGB ์ฑ๋) = 3๊ฐ
Level 1
- DC ์ฑ๋ถ: 1๊ฐ (RGB ๊ฐ๊ฐ 1๊ฐ์ฉ ์ด 3๊ฐ)
- Rest ์ฑ๋ถ: 3๊ฐ (๊ฐ RGB ์ฑ๋๋ง๋ค 3๊ฐ์ ์ถ๊ฐ ๊ณ์)
- ์ด ๊ณ์: 1๊ฐ์ DC ์ฑ๋ถ + 3๊ฐ์ rest ์ฑ๋ถ = 4๊ฐ
- RGB ๊ฐ๊ฐ์ ๋ํด 4๊ฐ์ ๊ณ์ = 4 x 3 = 12๊ฐ
Level 2
- DC ์ฑ๋ถ: 1๊ฐ (RGB ๊ฐ๊ฐ 1๊ฐ์ฉ ์ด 3๊ฐ)
- Rest ์ฑ๋ถ: 8๊ฐ (๊ฐ RGB ์ฑ๋๋ง๋ค 8๊ฐ์ ์ถ๊ฐ ๊ณ์)
- ์ด ๊ณ์: 1๊ฐ์ DC ์ฑ๋ถ + 8๊ฐ์ rest ์ฑ๋ถ = 9๊ฐ
- RGB ๊ฐ๊ฐ์ ๋ํด 9๊ฐ์ ๊ณ์ = 9 x 3 = 27๊ฐ
Level 3
- DC ์ฑ๋ถ: 1๊ฐ (RGB ๊ฐ๊ฐ 1๊ฐ์ฉ ์ด 3๊ฐ)
- Rest ์ฑ๋ถ: 15๊ฐ (๊ฐ RGB ์ฑ๋๋ง๋ค 15๊ฐ์ ์ถ๊ฐ ๊ณ์)
- ์ด ๊ณ์: 1๊ฐ์ DC ์ฑ๋ถ + 15๊ฐ์ rest ์ฑ๋ถ = 16๊ฐ
- RGB ๊ฐ๊ฐ์ ๋ํด 16๊ฐ์ ๊ณ์ = 16 x 3 = 48๊ฐ
์์ฝ
- DC ์ฑ๋ถ: ๊ธฐ๋ณธ์ ์ธ SH ์ฑ๋ถ์ผ๋ก, ๋ชจ๋ ๋ฐฉํฅ์์ ์ผ์ ํ ๊ฐ์ ๊ฐ์ง.
- Rest ์ฑ๋ถ: DC ์ฑ๋ถ์ ์ ์ธํ ๋๋จธ์ง ๊ณ ์ฐจ SH ์ฑ๋ถ๋ค.
- SH Levels์ ๊ณ์: ๊ฐ Level์ ๋ฐ๋ผ DC ์ฑ๋ถ์ ํฌํจํ ์ด ๊ณ์์ ์๊ฐ ๊ฒฐ์ ๋จ.
- RGB ์ฑ๋๋ณ ๋ ๋ฆฝ์ ๊ณ์: ๊ฐ RGB ์ฑ๋์ ๋ํด SH ๊ณ์๋ค์ด ๊ฐ๋ณ์ ์ผ๋ก ์กด์ฌํ์ฌ ์ปฌ๋ฌ ์ ๋ณด๋ฅผ ์ ํํ ํํํจ.
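All of the counts above follow from the fact that SH of degree d has (d + 1)**2 basis functions per channel; a minimal sketch:

for level in range(4):                    # SH level (degree) 0..3
    total_per_channel = (level + 1) ** 2  # DC + rest coefficients per channel
    rest_per_channel = total_per_channel - 1
    print(level, total_per_channel, rest_per_channel, 3 * total_per_channel)
# 0 1 0 3
# 1 4 3 12
# 2 9 8 27
# 3 16 15 48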
Now let's look at how initialization works when an Open3D mesh is loaded (as in SuGaR).
- surface_mesh_to_bind is an o3d mesh, and
- n_points = (number of triangles in surface_mesh_to_bind) * (number of gaussians per triangle).
- In other words, n_points is the total number of gaussians defined on the triangles.
@property
def n_points(self):
    if not self.binded_to_surface_mesh:
        return len(self._points)
    else:
        return self._n_points
Let's look at the code that defines the DC and rest parts of the spherical harmonics (SH) for these n_points.
colors has shape (n_vertices, n_coords).
- Here, n_coords is the color information of the vertices. To be precise, the comment should say rgb rather than n_coords.
- However, since colors has 3 rgb dimensions and vertices has 3 xyz dimensions, the comments appear to use n_coords for both (see the sketch below):
colors  # shape (n_vertices, n_coords) <-- the rgb color values of the vertices, written as n_coords
vertices  # shape (n_vertices, n_coords) <-- the xyz coordinates of the vertices, written as n_coords
Let's look at how sh_coordinates_dc is created from colors.
def RGB2SH(rgb):
    return (rgb - 0.5) / C0

def SH2RGB(sh):
    return sh * C0 + 0.5

sh_coordinates_dc = RGB2SH(colors).unsqueeze(dim=1)  # shape (n_vertices, 1, n_coords)
- ์ฃผ์์ ์ ๋๋ก ์ฐ๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
sh_coordinates_dc = RGB2SH(colors).unsqueeze(dim=1) # shape (n_vertices, 1, rgb)
- That is, for the last dimension of 3 rgb values, sh_coordinates_dc holds one SH coefficient per channel.
- self._sh_coordinates_rest is defined as: self._sh_coordinates_rest = torch.zeros(n_points, sh_levels**2 - 1, 3)  # shape (n_points, sh_levels**2 - 1, 3)
- That is, for the last dimension of 3 rgb values, self._sh_coordinates_rest holds sh_levels**2 - 1 coefficients per channel.
- In other words, SH coefficients are assigned per rgb channel.
# Initialize color features
self.sh_levels = sh_levels
sh_coordinates_dc = RGB2SH(colors).unsqueeze(dim=1)
self._sh_coordinates_dc = nn.Parameter(
    sh_coordinates_dc.to(self.nerfmodel.device),
    requires_grad=True and (not freeze_gaussians)
).to(self.nerfmodel.device)

self._sh_coordinates_rest = nn.Parameter(
    torch.zeros(n_points, sh_levels**2 - 1, 3).to(self.nerfmodel.device),
    requires_grad=True and (not freeze_gaussians)
).to(self.nerfmodel.device)
# gaussian-splatting/scene/gaussian_model.py
class GaussianModel:
    ...
    @property
    def get_features(self):
        features_dc = self._features_dc
        features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)
3dgs์์ _features_dc
, _features_rest
๋ spherical harmonics์ ๊ณ์์ ํด๋ฑํฉ๋๋ค.
- ๋ฐ๋ผ์ ์ด๋ฅผ SuGaR์์๋ ๋ณ์์ด๋ฆ์ ๊ทธ๋ฅ
_sh_coordinates_dc[...]
,_sh_coordinates_rest[...]
๋ก ๋ถ๋ฌ์์ ์ฌ์ฉํฉ๋๋ค.
# SuGaR/sugar_extractors/coarse_mesh.py
def extract_mesh_from_coarse_sugar(args):
    ...
    sugar._sh_coordinates_dc[...] = nerfmodel.gaussians._features_dc.detach()
    sugar._sh_coordinates_rest[...] = nerfmodel.gaussians._features_rest.detach()
3dgs์์ _features_dc
๋ RGB์ ํด๋นํ๋ sh(spherical harmonics)๊ณ์์
๋๋ค.
- ๋ฐ๋ผ์ ์ด๋ฅผ SuGaR์์๋
SH2RGB
ํจ์๋ก sh๋ฅผ rgb๋ก ๋ณํํ์ฌcolors
๋ณ์์ ํ ๋นํฉ๋๋ค.
# SuGaR/sugar_extractors/coarse_mesh.py
def extract_mesh_from_coarse_sugar(args):
    ...
    CONSOLE.print(f"\nLoading the coarse SuGaR model from path {sugar_checkpoint_path}...")
    checkpoint = torch.load(sugar_checkpoint_path, map_location=nerfmodel.device)
    colors = SH2RGB(checkpoint['state_dict']['_sh_coordinates_dc'][:, 0, :])
    sugar = SuGaR(
        nerfmodel=nerfmodel,
        points=checkpoint['state_dict']['_points'],
        colors=colors,
        initialize=True,
        sh_levels=nerfmodel.gaussians.active_sh_degree+1,
        keep_track_of_knn=True,
        knn_to_track=16,
        beta_mode='average',  # 'learnable', 'average', 'weighted_average'
        primitive_types='diamond',  # 'diamond', 'square'
        surface_mesh_to_bind=None,  # Open3D mesh
    )
    sugar.load_state_dict(checkpoint['state_dict'])
    sugar.eval()
- ๊ทธ๋ฆฌ๊ณ ์ด
colors
๋ ๋ถ๋ฌ์์RGB2SH
ํจ์๋ก ๋ณํํ๊ณunsqueeze
ํ์ฌ ์ฐจ์์ ๋ง์ถฐ ํ์ต์ initialize๋ก ์ค ์ ์์ต๋๋ค.
SH ๊ณ์๋ฅผ RGB๋ก ๋ณํํ๊ธฐ ์ํด ํ์ํ ์ ๋ณด
SH ๊ณ์๋ฅผ RGB๋ก ๋ณํํ๊ธฐ ์ํด์๋ ๋ค์์ ์ ๋ณด๊ฐ ํ์ํฉ๋๋ค: sh์ deg
, sh์ ๊ณ์
, camera center์์ point๊น์ง์ direction
.
- SuGaR์์๋
eval_sh(deg, sh, dirs)
๋ฅผ ์ฌ์ฉํ์ฌ ํน์ ์นด๋ฉ๋ผ ๋ฐฉํฅ์์ ํฌ์ธํธ๊น์ง์ ๋ ๋๋ง ๋ฐฉํฅ์ ๋ฐ๋ผ SH ๊ณ์๋ฅผ ํ๋์ RGB ์ปฌ๋ฌ๋ก ๋ณํํฉ๋๋ค. - RGB 3๊ฐ์ ์ฑ๋๋ง๋ค SH ๊ณ์๊ฐ ๋ ๋ฆฝ์ ์ผ๋ก ์กด์ฌํฉ๋๋ค.
- DC๋ ์กฐ๋ช ์ ์ ์ฒด์ ์ธ ๋ฐ๊ธฐ๋ฅผ ๋ํ๋ด๋ ์์์ ๋๋ค.
SH์ DC (0๋ฒ์งธ Band)
sh[..., 0]
์ 0๋ฒ์งธ deg์ ํด๋นํ๋ SH์ DC ๊ณ์๋ก, ๋ชจ๋ RGB ์ฑ๋์ ๋ํด ๋์ผํฉ๋๋ค.
SH์ Rest (1๋ฒ์งธ, 2๋ฒ์งธ, 3๋ฒ์งธ, โฆ)
sh[..., 1]
๋ถํฐsh[..., 24]
๊น์ง์ ๊ณ์๋ ๊ฐ๋๊ฐ ์๋ ์กฐ๋ช ๊ตฌ์ฑ ์์์ ๋๋ค.- SH ๊ณ์๋ ๊ฐ RGB ์ฑ๋์ ๋ํด ๊ฐ๋ณ์ ์ผ๋ก ๊ณ์ฐ๋ฉ๋๋ค. ์ฆ, R, G, B ๊ฐ ์ฑ๋์ ์๋ก ๋ค๋ฅธ SH ๊ณ์๋ฅผ ๊ฐ์ง ์ ์์ต๋๋ค.
์์๋ก ์ค๋ช
์๋ฅผ ๋ค์ด, RGB ์ฑ๋ ๊ฐ๊ฐ์ ๋ํด SH ๊ณ์๊ฐ ๋ค์๊ณผ ๊ฐ์ด ์์ ์ ์์ต๋๋ค:
R ์ฑ๋
sh[..., 0] = 0.5
sh[..., 1] = 0.1
sh[..., 2] = 0.3
...
G ์ฑ๋
sh[..., 0] = 0.4
sh[..., 1] = 0.2
sh[..., 2] = 0.6
...
B ์ฑ๋
sh[..., 0] = 0.7
sh[..., 1] = 0.3
sh[..., 2] = 0.5
...
๊ฐ ๊ณ์๋ ๊ฐ RGB ์ฑ๋์ ๋ํด ๊ฐ๋ณ์ ์ผ๋ก ๊ณ์ฐ๋ฉ๋๋ค:
- sh[โฆ, 1]์ 1๋ฒ์งธ deg์ ํด๋นํ๋ SH์ ๊ณ์๋ก, ๊ฐ RGB ์ฑ๋์ ๋ํด ๋ ๋ฆฝ์ ์ผ๋ก ๊ณ์ฐ๋ ๊ฐ์ ๊ฐ์ง๋๋ค.
- sh[โฆ, 2]์ 2๋ฒ์งธ deg์ ํด๋นํ๋ SH์ ๊ณ์๋ก, ๊ฐ RGB ์ฑ๋์ ๋ํด ๋ ๋ฆฝ์ ์ผ๋ก ๊ณ์ฐ๋ ๊ฐ์ ๊ฐ์ง๋๋ค.
- sh[โฆ, 3]์ 3๋ฒ์งธ deg์ ํด๋นํ๋ SH์ ๊ณ์๋ก, ๊ฐ RGB ์ฑ๋์ ๋ํด ๋ ๋ฆฝ์ ์ผ๋ก ๊ณ์ฐ๋ ๊ฐ์ ๊ฐ์ง๋๋ค.
- โฆ
- sh[โฆ, 24]์ 24๋ฒ์งธ deg์ ํด๋นํ๋ SH์ ๊ณ์๋ก, ๊ฐ RGB ์ฑ๋์ ๋ํด ๋ ๋ฆฝ์ ์ผ๋ก ๊ณ์ฐ๋ ๊ฐ์ ๊ฐ์ง๋๋ค.
์์ฝ
- SH ๊ณ์๋ ๊ฐ RGB ์ฑ๋์ ๋ํด ๊ฐ๋ณ์ ์ผ๋ก ์กด์ฌํ๊ณ , ๊ณ์ฐ๋ฉ๋๋ค.
- ์ด๋ ๋์ผํ ๊ฐ์ ๊ฐ์ง๋ค๋ ์๋ฏธ๊ฐ ์๋๋ฉฐ, ๊ฐ ์ฑ๋๋ณ๋ก ๋ค๋ฅธ ๊ฐ์ ๊ฐ์ง ์ ์์ต๋๋ค.
- ๊ฐ ์ฑ๋์ SH ๊ณ์๋ ๊ฐ๋ณ์ ์ผ๋ก ์ฒ๋ฆฌ๋๋ฉฐ, eval_sh ํจ์๋ ๊ฐ ์ฑ๋์ ๋ํด ๋ ๋ฆฝ์ ์ผ๋ก SH ๊ณ์๋ฅผ ์ฌ์ฉํ์ฌ RGB ๊ฐ์ ๊ณ์ฐํฉ๋๋ค.
# SuGaR/sugar_scene/sugar_model.py
def get_points_rgb(
    self,
    positions: torch.Tensor = None,
    camera_centers: torch.Tensor = None,
    directions: torch.Tensor = None,
    sh_levels: int = None,
    sh_coordinates: torch.Tensor = None,
):
    """Returns the RGB color of the points for the given camera pose.
    Args:
        positions (torch.Tensor, optional): Shape (n_pts, 3). Defaults to None.
        camera_centers (torch.Tensor, optional): Shape (n_pts, 3) or (1, 3). Defaults to None.
        directions (torch.Tensor, optional): _description_. Defaults to None.
    Raises:
        ValueError: _description_
    Returns:
        _type_: _description_
    """
    if positions is None:
        positions = self.points

    if camera_centers is not None:
        render_directions = torch.nn.functional.normalize(positions - camera_centers, dim=-1)
    elif directions is not None:
        render_directions = directions
    else:
        raise ValueError("Either camera_centers or directions must be provided.")

    if sh_coordinates is None:
        sh_coordinates = self.sh_coordinates

    if sh_levels is None:
        sh_coordinates = sh_coordinates
    else:
        sh_coordinates = sh_coordinates[:, :sh_levels**2]

    shs_view = sh_coordinates.transpose(-1, -2).view(-1, 3, sh_levels**2)
    sh2rgb = eval_sh(sh_levels-1, shs_view, render_directions)
    colors = torch.clamp_min(sh2rgb + 0.5, 0.0).view(-1, 3)

    return colors
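To make the tensor shapes in get_points_rgb concrete, here is a minimal sketch with dummy tensors (shapes only) of the reshaping that happens right before eval_sh is called:

import torch

n_pts, sh_levels = 1000, 4                              # sh_levels = 4 corresponds to 3DGS's degree 3
sh_coordinates = torch.randn(n_pts, sh_levels ** 2, 3)  # (n_pts, 16, 3): point, SH coefficient, RGB channel
shs_view = sh_coordinates.transpose(-1, -2).view(-1, 3, sh_levels ** 2)
print(shs_view.shape)                                   # (1000, 3, 16): eval_sh expects [..., C, (deg + 1) ** 2]

render_directions = torch.nn.functional.normalize(torch.randn(n_pts, 3), dim=-1)
# eval_sh(sh_levels - 1, shs_view, render_directions) returns shape (n_pts, 3),
# which get_points_rgb then shifts by +0.5 and clamps to obtain the final colors.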
# SuGaR/sugar_utils/spherical_harmonics.py
def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp.
    ... Can be 0 or more batch dimensions.
    Args:
        deg: int SH deg. Currently, 0-3 supported
        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
        dirs: jnp.ndarray unit directions [..., 3]
    Returns:
        [..., C]
    """
    assert deg <= 4 and deg >= 0
    coeff = (deg + 1) ** 2
    assert sh.shape[-1] >= coeff

    result = C0 * sh[..., 0]
    if deg > 0:
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        result = (result -
                  C1 * y * sh[..., 1] +
                  C1 * z * sh[..., 2] -
                  C1 * x * sh[..., 3])

        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result = (result +
                      C2[0] * xy * sh[..., 4] +
                      C2[1] * yz * sh[..., 5] +
                      C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
                      C2[3] * xz * sh[..., 7] +
                      C2[4] * (xx - yy) * sh[..., 8])

            if deg > 2:
                result = (result +
                          C3[0] * y * (3 * xx - yy) * sh[..., 9] +
                          C3[1] * xy * z * sh[..., 10] +
                          C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
                          C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
                          C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
                          C3[5] * z * (xx - yy) * sh[..., 14] +
                          C3[6] * x * (xx - 3 * yy) * sh[..., 15])

                if deg > 3:
                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
                              C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
                              C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
                              C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
                              C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
                              C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
                              C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
                              C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
                              C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result
C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
1.0925484305920792,
-1.0925484305920792,
0.31539156525252005,
-1.0925484305920792,
0.5462742152960396
]
C3 = [
-0.5900435899266435,
2.890611442640554,
-0.4570457994644658,
0.3731763325901154,
-0.4570457994644658,
1.445305721320277,
-0.5900435899266435
]
C4 = [
2.5033429417967046,
-1.7701307697799304,
0.9461746957575601,
-0.6690465435572892,
0.10578554691520431,
-0.6690465435572892,
0.47308734787878004,
-1.7701307697799304,
0.6258357354491761,
]
SH2RGB์ RGB2SH๋ sh์ 0๋ฒ์งธ band์ธ dc ์ฑ๋ถ์์๋ง ์ฌ์ฉ๋๋ ํจ์์ ๋๋ค.
def RGB2SH(rgb):
    return (rgb - 0.5) / C0

def SH2RGB(sh):
    return sh * C0 + 0.5
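A minimal sketch (dummy values) showing that the two functions are inverses of each other, and how they relate to the degree-0 term of eval_sh:

import torch

C0 = 0.28209479177387814

def RGB2SH(rgb):
    return (rgb - 0.5) / C0

def SH2RGB(sh):
    return sh * C0 + 0.5

rgb = torch.rand(5, 3)
dc = RGB2SH(rgb)
print(torch.allclose(SH2RGB(dc), rgb))  # True: the two functions undo each other

# For deg = 0, eval_sh reduces to C0 * sh[..., 0]; adding the +0.5 offset used in
# get_points_rgb therefore reproduces SH2RGB on the DC band.
print(torch.allclose(C0 * dc + 0.5, rgb))  # True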
RGB2SH ํจ์
sh_coordinates_dc = RGB2SH(colors).unsqueeze(dim=1)
RGB2SH ํจ์
๋RGB ๊ฐ์ ๋ฐ์์ SH ๊ณ์๋ก ๋ณํ
ํฉ๋๋ค.๋ณํ๋ SH ๊ณ์๋ DC ์ฑ๋ถ์ ํด๋น
ํฉ๋๋ค.RGB2SH(colors)
:colors๋ RGB ๊ฐ
์ ๋ํ๋ด๋ฉฐ,DC ์ฑ๋ถ์ ํด๋น
ํฉ๋๋ค.- ๋ฐ๋ผ์, ์ด ๋ณํ์
RGB ๊ฐ์ DC ์ฑ๋ถ์ SH ๊ณ์๋ก ๋ณํ
ํฉ๋๋ค. .unsqueeze(dim=1)
: ์ฐจ์์ ์ถ๊ฐํ์ฌ SH ๊ณ์์ DC ์ฑ๋ถ์ 3D ํ ์๋ก ๋ง๋ญ๋๋ค.
SH2RGB ํจ์
refined_sugar = SuGaR(
    nerfmodel=nerfmodel,
    points=checkpoint['state_dict']['_points'],
    colors=SH2RGB(checkpoint['state_dict']['_sh_coordinates_dc'][:, 0, :]),
    initialize=False,
    sh_levels=nerfmodel.gaussians.active_sh_degree+1,
    keep_track_of_knn=False,
    knn_to_track=0,
    beta_mode='average',
    surface_mesh_to_bind=o3d_mesh,
    n_gaussians_per_surface_triangle=n_gaussians_per_surface_triangle,
)
refined_sugar.load_state_dict(checkpoint['state_dict'])

textures_uv = TexturesUV(
    maps=SH2RGB(self.texture_features[..., 0, :][None]),  # texture_img[None]
    verts_uvs=self.verts_uv[None],
    faces_uvs=self.faces_uv[None],
    sampling_mode='nearest',
)
SH2RGB ํจ์
๋SH ๊ณ์๋ฅผ ๋ฐ์์ RGB ๊ฐ์ผ๋ก ๋ณํ
ํฉ๋๋ค.- ์ด ํจ์๋
DC ์ฑ๋ถ์ ํด๋นํ๋ SH ๊ณ์๋ฅผ ๋ณํ
ํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. self.texture_features[..., 0, :]
์ SH ๊ณ์์ ์ฒซ ๋ฒ์งธ ์ฑ๋ถ(DC ์ฑ๋ถ)์ ์ ํํฉ๋๋ค.[None]
์ ์ฌ์ฉํ์ฌ ์ฐจ์์ ์ถ๊ฐํฉ๋๋ค.SH2RGB ํจ์
๋ ์ดDC ์ฑ๋ถ์ RGB ๊ฐ์ผ๋ก ๋ณํ
ํฉ๋๋ค.
๊ฒฐ๋ก
RGB2SH
์ SH2RGB
๋ ์ฃผ์ด์ง ์ฝ๋์์ SH ๊ณ์์ DC ์ฑ๋ถ์ ํด๋นํ๋ ๊ฐ์ ๋ณํ
ํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. ๋ฐ๋ผ์, ์ด ๋ ํจ์๋ DC ์ฑ๋ถ์๋ง ํด๋น
ํ๋ค๊ณ ๋ณผ ์ ์์ต๋๋ค.
์ต์ข ์์ฝ
RGB2SH ํจ์
๋DC ์ฑ๋ถ์ ํด๋นํ๋ RGB ๊ฐ์ SH ๊ณ์๋ก ๋ณํ
ํฉ๋๋ค.SH2RGB ํจ์
๋DC ์ฑ๋ถ์ ํด๋นํ๋ SH ๊ณ์๋ฅผ RGB ๊ฐ์ผ๋ก ๋ณํ
ํฉ๋๋ค.- ์ฃผ์ด์ง ์ฝ๋์์๋ ์ด ๋ ํจ์๊ฐ SH ๊ณ์์ DC ์ฑ๋ถ์ ์ฃผ๋ก ์ฌ์ฉ๋๊ณ ์์ต๋๋ค.
[3D CV ์ฐ๊ตฌ] 3DGS input & output .ply properties & Meshlab Vert & Spherical Harmonics (SH) & Mesh