[Mesh] How Gaussians turn into Primitives in SuGaR model
SuGaR์์ Gaussians์ด ์ด๋ป๊ฒ Primitives (diamond ๋๋ square(=quad))๋ก ๋ณํ๋๋์ง ์์๋ด ์๋ค.
SuGaR ๋ชจ๋ธ์์๋ Gaussian splats๋ฅผ ๋ ์ ๋ฐํ ๋ค๊ฐํ(primitive polygon) ํํ๋ก ๋์ฒดํ์ฌ 3D ํ๋ฉด์ ๋ณด๋ค ์ ํํ๊ฒ ํํํฉ๋๋ค.
Gaussian splats๊ฐ ๋ค์ด์๋ชฌ๋(diamond) ๋๋ ์ฌ๊ฐํ(square) ํํ์ ๋ค๊ฐํ์ผ๋ก ๋ณํ๋๋ ๊ณผ์ ์ ์์๋ด ์๋ค.
๋ค์ด์๋ชฌ๋์ ์ฌ๊ฐํ์ ์ฐจ์ด์
๋ค์ด์๋ชฌ๋์ ์ฌ๊ฐํ์ ์ ์ ์ ๊ฐ์๋ ๊ฐ์ง๋ง, ์ ์ ์ ๋ฐฐ์ด ๋ฐฉ์๊ณผ ๊ทธ์ ๋ฐ๋ฅธ ๊ตฌ์กฐ๊ฐ ๋ค๋ฆ ๋๋ค. ๋ค์ด์๋ชฌ๋๋ ์ค์ฌ์ ๊ธฐ์ค์ผ๋ก ์์๋, ์ข์ฐ๋ก ๋ฐฐ์ด๋ ์ ์ฒด์ ํํ๋ฅผ ๊ฐ์ง๋ฉฐ, ์ฌ๊ฐํ์ ํ๋ฉด ์์์ ๋ค ๋ฐฉํฅ์ผ๋ก ๋ฐฐ์ด๋ ํ๋ฉด์ ํํ๋ฅผ ๊ฐ์ง๋๋ค.
์ฌ๊ธฐ์ ์ฌ์ฉ๋ square๋ ์ผ๋ฐ์ ์ผ๋ก 3D ๊ทธ๋ํฝ์ค์์ โquadโ๋ผ๊ณ ๋ถ๋ฆฌ๋ ์ฌ๊ฐํ์ ์๋ฏธํฉ๋๋ค. quad๋ ๋ค ๊ฐ์ ์ ์ ์ ๊ฐ์ง ๋ค๊ฐํ์ ์๋ฏธํ๋ฉฐ, ์ด ์ฝ๋์์๋ Gaussian splats๋ฅผ ๋์ฒดํ๊ธฐ ์ํด ์ฌ์ฉ๋ฉ๋๋ค.
sugar_model.py
์์ diamond์ sqaure primitive ์ ์๋ ๋ค์๊ณผ ๊ฐ์ด ํฉ๋๋ค.
# Primitive polygon that will be used to replace the gaussians
self.primitive_types = primitive_types
self._diamond_verts = torch.Tensor(
[[0., -1., 0.], [0., 0, 1.],
[0., 1., 0.], [0., 0., -1.]]
).to(nerfmodel.device)
self._square_verts = torch.Tensor(
[[0., -1., 1.], [0., 1., 1.],
[0., 1., -1.], [0., -1., -1.]]
).to(nerfmodel.device)
if primitive_types == 'diamond':
self.primitive_verts = self._diamond_verts # Shape (n_vertices_per_gaussian, 3)
elif primitive_types == 'square':
self.primitive_verts = self._square_verts # Shape (n_vertices_per_gaussian, 3)
self.primitive_triangles = torch.Tensor(
[[0, 2, 1], [0, 3, 2]]
).to(nerfmodel.device).long() # Shape (n_triangles_per_gaussian, 3)
self.primitive_border_edges = torch.Tensor(
[[0, 1], [1, 2], [2, 3], [3, 0]]
).to(nerfmodel.device).long() # Shape (n_edges_per_gaussian, 2)
self.n_vertices_per_gaussian = len(self.primitive_verts)
self.n_triangles_per_gaussian = len(self.primitive_triangles)
self.n_border_edges_per_gaussian = len(self.primitive_border_edges)
self.triangle_scale = triangle_scale
์ด๋ฅผ visualizeํ๋ ์ฝ๋๋ฅผ ์์ฑํ์ฌ ๋ด ์๋ค.
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
# Convert tensors to numpy arrays
_diamond_verts = np.array([[0., -1., 0.], [0., 0, 1.], [0., 1., 0.], [0., 0., -1.]])
_square_verts = np.array([[0., -1., 1.], [0., 1., 1.], [0., 1., -1.], [0., -1., -1.]])
primitive_triangles = np.array([[0, 2, 1], [0, 3, 2]])
primitive_border_edges = np.array([[0, 1], [1, 2], [2, 3], [3, 0]])
# Function to plot vertices
def plot_vertices(vertices, ax, color='b'):
ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2], color=color, s=100)
# Function to plot edges
def plot_edges(vertices, edges, ax, color='r'):
for edge in edges:
edge_verts = vertices[edge]
ax.plot(edge_verts[:, 0], edge_verts[:, 1], edge_verts[:, 2], color=color)
# Function to plot triangles
def plot_triangles(vertices, triangles, ax, color='g'):
for tri in triangles:
tri_verts = vertices[tri]
poly3d = [[tri_verts[0], tri_verts[1], tri_verts[2]]]
ax.add_collection3d(Poly3DCollection(poly3d, facecolors=color, linewidths=1, edgecolors='k', alpha=.25))
# Plotting
fig = plt.figure(figsize=(18, 6))
# Plot diamond vertices and edges
ax1 = fig.add_subplot(131, projection='3d')
plot_vertices(_diamond_verts, ax1, color='b')
plot_edges(_diamond_verts, primitive_border_edges, ax1, color='r')
plot_triangles(_diamond_verts, primitive_triangles, ax1, color='g')
ax1.set_title('Diamond')
ax1.set_xlabel('X')
ax1.set_ylabel('Y')
ax1.set_zlabel('Z')
# Plot square vertices and edges
ax2 = fig.add_subplot(132, projection='3d')
plot_vertices(_square_verts, ax2, color='b')
plot_edges(_square_verts, primitive_border_edges, ax2, color='r')
plot_triangles(_square_verts, primitive_triangles, ax2, color='g')
ax2.set_title('Square')
ax2.set_xlabel('X')
ax2.set_ylabel('Y')
ax2.set_zlabel('Z')
# Separate plot for _diamond_verts and _square_verts visualization
ax3 = fig.add_subplot(133, projection='3d')
plot_vertices(_diamond_verts, ax3, color='b')
plot_vertices(_square_verts, ax3, color='r')
ax3.set_title('Diamond (blue) and Square (red) Vertices')
ax3.set_xlabel('X')
ax3.set_ylabel('Y')
ax3.set_zlabel('Z')
plt.show()
์์ฝ
- ๋ค์ด์๋ชฌ๋์ ์ฌ๊ฐํ์ ์ ์ ๋ฐฐ์ด ์์:
- ๋ค์ด์๋ชฌ๋๋ ์ค์์ ๊ธฐ์ค์ผ๋ก ์์๋, ์ข์ฐ๋ก ๋ฐฐ์ด๋ฉ๋๋ค.
- ์ฌ๊ฐํ์ ํ๋ฉด ์์์ ๋ค ๋ฐฉํฅ์ผ๋ก ๋ฐฐ์ด๋ฉ๋๋ค.
- ํ์ฑ๋ ๊ตฌ์กฐ:
- ๋ค์ด์๋ชฌ๋๋ ์ ์ฒด์ ์ธ ๊ตฌ์กฐ๋ฅผ ํ์ฑํฉ๋๋ค.
- ์ฌ๊ฐํ(quad)์ ํ๋ฉด์ ์ธ ๊ตฌ์กฐ๋ฅผ ํ์ฑํฉ๋๋ค.
- ์ด ๋ ๊ฐ์ง ํํ๋ ์ ์ ์ ๋ฐฐ์ด ์์์ ๊ทธ์ ๋ฐ๋ฅธ ์ผ๊ฐํ์ ๋ฐฐ์น๊ฐ ๋ฌ๋ผ์ ์๋ก ๋ค๋ฅธ ๊ตฌ์กฐ๋ฅผ ํ์ฑํฉ๋๋ค. ์ด๋ฅผ ํตํด Gaussian splats๋ฅผ ๋ค์ํ ํํ๋ก ๋ณํํ์ฌ ๋ ์ ๋ฐํ๊ณ ๋ณต์กํ 3D ๋ชจ๋ธ์ ํํํ ์ ์์ต๋๋ค.
Leave a comment