4 minute read

# sugar_scene/sugar_model.py

class SuGaR(nn.Module):

...

        # ---Tools for future meshing---
        # Primitive polygon that will be used to replace the gaussians
        self.primitive_types = primitive_types
        self._diamond_verts = torch.Tensor(
                [[0., -1., 0.], [0., 0, 1.], 
                [0., 1., 0.], [0., 0., -1.]]
                ).to(nerfmodel.device)
        self._square_verts = torch.Tensor(
                [[0., -1., 1.], [0., 1., 1.], 
                [0., 1., -1.], [0., -1., -1.]]
                ).to(nerfmodel.device)
        if primitive_types == 'diamond':
            self.primitive_verts = self._diamond_verts  # Shape (n_vertices_per_gaussian, 3)
        elif primitive_types == 'square':
            self.primitive_verts = self._square_verts  # Shape (n_vertices_per_gaussian, 3)
        self.primitive_triangles = torch.Tensor(
            [[0, 2, 1], [0, 3, 2]]
            ).to(nerfmodel.device).long()  # Shape (n_triangles_per_gaussian, 3)
        self.primitive_border_edges = torch.Tensor(
            [[0, 1], [1, 2], [2, 3], [3, 0]]
            ).to(nerfmodel.device).long()  # Shape (n_edges_per_gaussian, 2)
        self.n_vertices_per_gaussian = len(self.primitive_verts)
        self.n_triangles_per_gaussian = len(self.primitive_triangles)
        self.n_border_edges_per_gaussian = len(self.primitive_border_edges)
        self.triangle_scale = triangle_scale

...

   @property
    def triangle_vertices(self):
        # Apply shift to triangle vertices
        if self.primitive_types == 'diamond':
            self.primitive_verts = self._diamond_verts
        elif self.primitive_types == 'square':
            self.primitive_verts = self._square_verts
        else:
            raise ValueError("Unknown primitive type: {}".format(self.primitive_types))
        triangle_vertices = self.primitive_verts[None]  # Shape: (1, n_vertices_per_gaussian, 3)
        
        # Move canonical, shifted triangles to the local gaussian space
        # We need to permute the scaling axes so that the smallest is the first
        scale_argsort = self.scaling.argsort(dim=-1)
        scale_argsort[..., 1] = (scale_argsort[..., 0] + 1) % 3
        scale_argsort[..., 2] = (scale_argsort[..., 0] + 2) % 3
        
        # TODO: Change for a lighter computation that does not require to compute the rotation matrices.
        # We can just permute the axes of triangle_vertices with the inverse permutation.
        
        # Permute scales
        scale_sort = self.scaling.gather(dim=1, index=scale_argsort)
        
        # Permute rotation axes
        rotation_matrices = quaternion_to_matrix(self.quaternions)
        rotation_sort = rotation_matrices.gather(dim=2, index=scale_argsort[..., None, :].expand(-1, 3, -1))
        quaternion_sort = matrix_to_quaternion(rotation_sort)
        
        triangle_vertices = self.points.unsqueeze(1) + quaternion_apply(
            quaternion_sort.unsqueeze(1),
            triangle_vertices * self.triangle_scale * scale_sort.unsqueeze(1))
        
        triangle_vertices = triangle_vertices.view(-1, 3)  # Shape: (n_pts * n_vertices_per_gaussian, 3)
        return triangle_vertices
    
    @property
    def triangle_border_edges(self):
        edges = self.primitive_border_edges[None]  # Shape: (1, n_border_edges_per_gaussian, 2)
        edges = edges + 4 * torch.arange(len(self.points), device=self.device)[:, None, None]  # Shape: (n_pts, n_border_edges_per_gaussian, 2)
        edges = edges.view(-1, 2)  # Shape: (n_pts * n_border_edges_per_gaussian, 2)
        return edges
    
    @property
    def triangles(self):
        triangles = self.primitive_triangles[None].expand(self.n_points, -1, -1).clone()  # Shape: (n_pts, n_triangles_per_gaussian, 3)
        triangles = triangles + 4 * torch.arange(len(self.points), device=self.device)[:, None, None]  # Shape: (n_pts, n_triangles_per_gaussian, 3)
        triangles = triangles.view(-1, 3)  # Shape: (n_pts * n_triangles_per_gaussian, 3)
        return triangles

  • ์œ„ ์ฝ”๋“œ์—์„œ primitive๊ฐ€ diamond ํ˜น์€ square๋กœ ๊ฒฐ์ •๋˜๋Š”๋ฐ, ์—ฌ๊ธฐ์„œ ๋‘ ์ผ€์ด์Šค ๋ชจ๋‘ ์ •์ ์ธ vertices์˜ ๊ฐœ์ˆ˜๋Š” 4๊ฐœ์ž…๋‹ˆ๋‹ค.
  • len(self.primitive_verts)๋กœ self.n_vertices_per_gaussian์œผ๋กœ ์„ค์ •ํ•˜๋Š”๋ฐ, ์ด ๋ง์€ gaussian๋‹น vertices๋ฅผ 4๊ฐœ๋กœ ์ •ํ•œ๋‹ค๋Š” ์˜๋ฏธ์ž…๋‹ˆ๋‹ค.
  • ์ฆ‰, Gaussian์˜ primitive representation์„ diamond ํ˜น์€ sqaure๋กœ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.
  • ๊ทธ๋ฆฌ๊ณ  diamond ํ˜น์€ sqaure๋Š” top triangle๊ณผ bottom triangle 2๊ฐœ๋กœ ๊ตฌ์„ฑ๋ฉ๋‹ˆ๋‹ค.
  • ๋”ฐ๋ผ์„œ diamond ํ˜น์€ sqaure์— ๋Œ€ํ•œ vertices, edges๋Š” triangle๋กœ ๋ชจ๋‘ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.
  • triangle_vertices๋ฅผ ์ •์˜ํ•  ๋•Œ๋Š”, None์œผ๋กœ ๋ฐฐ์น˜์ฐจ์› 1์„ ์•ž์— ์ถ”๊ฐ€ํ•ด์ค๋‹ˆ๋‹ค.
  • ์ฆ‰ triangle_vertices๋Š” 1๊ฐœ์˜ triangle์— ๋Œ€ํ•˜์—ฌ vertices๋Š” 4๊ฐœ๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ๊ณ , ๊ฐ vertices๋Š” 3์ฐจ์› (x,y,z)๋ฅผ ๊ฐ€์ง€๋ฏ€๋กœ shape์€ (1, n_vertices_per_gaussian, 3)์—์„œ ๊ฒฐ๊ณผ์ ์œผ๋กœ (1, 4, 3)์ด ๋ฉ๋‹ˆ๋‹ค.
  • ๊ฒฐ๋ก ์ ์œผ๋กœ triangle 1๊ฐœ๋‹น local gaussian 1๊ฐœ๋ฅผ ํ• ๋‹นํ•˜๊ธฐ ์œ„ํ•œ ์ž‘์—…์ž…๋‹ˆ๋‹ค.
self.n_vertices_per_gaussian = len(self.primitive_verts)
...
triangle_vertices = self.primitive_verts[None]  # Shape: (1, n_vertices_per_gaussian, 3)
  • triangle๋‹น vertices๋ฅผ ์ •์˜ํ–ˆ์œผ๋‹ˆ, points ์ˆ˜๋งŒํผ edges์™€ triangles์„ ์ •์˜ํ•ด์ค๋‹ˆ๋‹ค.
  • edges ์†์„ฑ์€ Gaussian๊ณผ ์—ฐ๊ด€๋œ ๊ฐ primitive shape์˜ wireframe ๋˜๋Š” boundary๋ฅผ ์ •์˜ํ•˜๋Š” ๋ฐ ๋งค์šฐ ์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค. ์ด edges๋Š” rendering ๋ชฉ์ ์ด๋‚˜ ๋„ํ˜• ํ‘œ๋ฉด์—์„œ ์ˆ˜ํ–‰ํ•ด์•ผ ํ•˜๋Š” ๋ชจ๋“  geometric computations์— ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ์š”์•ฝํ•˜์ž๋ฉด, edges ์†์„ฑ์€ Gaussian์„ ๋Œ€์ฒดํ•˜๋Š” ๋ชจ๋“  primitive shapes์˜ boundary edges๋ฅผ ํฌํ•จํ•˜๋Š” tensor๋ฅผ ๊ตฌ์„ฑํ•˜๊ณ  ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด ๋ชจ๋ธ์€ ๊ฐ Gaussian์˜ primitive representation์—์„œ vertices ๊ฐ„์˜ ๊ตฌ์กฐ์™€ connections์„ ์ถ”์ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
    @property
    def triangle_border_edges(self):
        edges = self.primitive_border_edges[None]  # Shape: (1, n_border_edges_per_gaussian, 2)
        edges = edges + 4 * torch.arange(len(self.points), device=self.device)[:, None, None]  # Shape: (n_pts, n_border_edges_per_gaussian, 2)
        edges = edges.view(-1, 2)  # Shape: (n_pts * n_border_edges_per_gaussian, 2)
        return edges
    
    @property
    def triangles(self):
        triangles = self.primitive_triangles[None].expand(self.n_points, -1, -1).clone()  # Shape: (n_pts, n_triangles_per_gaussian, 3)
        triangles = triangles + 4 * torch.arange(len(self.points), device=self.device)[:, None, None]  # Shape: (n_pts, n_triangles_per_gaussian, 3)
        triangles = triangles.view(-1, 3)  # Shape: (n_pts * n_triangles_per_gaussian, 3)
        return triangles

TexturesUV์™€ TexturesVertex์˜ ์ฐจ์ด

TexturesUV:

  • UV ๋งคํ•‘ ๊ธฐ๋ฐ˜ ํ…์Šค์ฒ˜๋ง์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
  • ํ…์Šค์ฒ˜ ๋งต๊ณผ vertices, faces์˜ UV ์ขŒํ‘œ๋ฅผ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.
  • ์ด๋Š” ํ…์Šค์ฒ˜ ๋งต์„ ๊ฐ ์‚ผ๊ฐํ˜•์— ๋งคํ•‘ํ•˜์—ฌ ๋ Œ๋”๋งํ•ฉ๋‹ˆ๋‹ค.
  • ์ฃผ๋กœ ์ „์ฒด ์ด๋ฏธ์ง€ ๊ธฐ๋ฐ˜ ํ…์Šค์ฒ˜๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒฝ์šฐ์— ์ ํ•ฉํ•ฉ๋‹ˆ๋‹ค.

TexturesVertex:

  • Vertex ๊ธฐ๋ฐ˜ ํ…์Šค์ฒ˜๋ง์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
  • ๊ฐ vertex๋งˆ๋‹ค ํ…์Šค์ฒ˜ ์ƒ‰์ƒ ํŠน์ง•์„ ์ง์ ‘ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.
  • ์ด๋Š” ๊ฐ vertex์— ์ƒ‰์ƒ์„ ์ง€์ •ํ•˜์—ฌ ๋ Œ๋”๋งํ•ฉ๋‹ˆ๋‹ค.
  • ์ฃผ๋กœ ๊ฐ vertex์— ๊ฐœ๋ณ„ ์ƒ‰์ƒ์„ ์ง€์ •ํ•ด์•ผ ํ•˜๋Š” ๊ฒฝ์šฐ์— ์ ํ•ฉํ•ฉ๋‹ˆ๋‹ค.

๋‘ ๋ฐฉ์‹์˜ ์ฃผ์š” ์ฐจ์ด๋Š” ํ…์Šค์ฒ˜๊ฐ€ ์ •์˜๋˜๋Š” ๋ฐฉ์‹๊ณผ ๊ทธ์— ๋”ฐ๋ฅธ ๋ Œ๋”๋ง ๋ฐฉ์‹์ž…๋‹ˆ๋‹ค. TexturesUV๋Š” UV ์ขŒํ‘œ๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ํ…์Šค์ฒ˜ ๋งต์„ ๋งคํ•‘ํ•˜๋Š” ๋ฐ˜๋ฉด, TexturesVertex๋Š” ๊ฐ vertex์— ์ง์ ‘ ์ƒ‰์ƒ ์ •๋ณด๋ฅผ ํ• ๋‹นํ•ฉ๋‹ˆ๋‹ค.

Mesh texturing options (Vertex textures, Texture map + Vertex UV coordinates, Texture Atlas)

pytorch3d๋Š” Mesh texturing์„ ์œ„ํ•ด ์—ฌ๋Ÿฌ ์˜ต์…˜์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.

  • Vertex textures (TexturesVertex): ๊ฐ€์žฅ ๊ฐ„๋‹จํ•œ ๋ฐฉ๋ฒ•์€ ๊ฐ ์ •์ ์— ๋Œ€ํ•ด d์ฐจ์› ํ…์Šค์ฒ˜๋ฅผ ๊ฐ€์ง€๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, RGB ์ƒ‰์ƒ์ด d์ฐจ์›์ด ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ  3๊ฐœ์˜ ์ •์ ์— ์กด์žฌํ•˜๋Š” d์ฐจ์› ํ…์Šค์ฒ˜๋Š” face๋ฅผ ๊ฐ€๋กœ์งˆ๋Ÿฌ ๋ณด๊ฐ„๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Š” N x V x D ํ…์„œ๋กœ ํ‘œํ˜„๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

image

  • Texture map + Vertex UV coordinates (TexturesUV): ๋‘ ๋ฒˆ์งธ ๋ฐฉ๋ฒ•์€ ์ •์  UV ์ขŒํ‘œ์™€ ์ „์ฒด face์— ๋Œ€ํ•œ ๋‹จ์ผ ํ…์Šค์ฒ˜ ๋งต์„ ๊ฐ€์ง€๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. face์˜ ํŠน์ • ์ง€์ ์— ๋Œ€ํ•ด ์ƒ‰์ƒ์€ UV ์ขŒํ‘œ๋ฅผ ๋ณด๊ฐ„ํ•œ ๋‹ค์Œ ํ…์Šค์ฒ˜ ๋งต์—์„œ ์ƒ˜ํ”Œ๋งํ•˜์—ฌ ๊ณ„์‚ฐ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด ํ‘œํ˜„์€ ๋‘ ๊ฐœ์˜ ํ…์„œ๋ฅผ ํ•„์š”๋กœ ํ•˜๋ฉฐ mesh๋‹น ํ•˜๋‚˜์˜ ํ…์Šค์ฒ˜ ๋งต๋งŒ ์ง€์›ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

image

  • Texture Atlas: ๋ณด๋‹ค ๋ณต์žกํ•œ ๊ฒฝ์šฐ, ์˜ˆ๋ฅผ ๋“ค์–ด ShapeNet meshes์˜ ๊ฒฝ์šฐ, mesh๋‹น ์—ฌ๋Ÿฌ ๊ฐœ์˜ ํ…์Šค์ฒ˜ ๋งต์ด ์žˆ์œผ๋ฉฐ ์ผ๋ถ€ face๋Š” ํ…์Šค์ฒ˜๊ฐ€ ์—†๊ณ  ๋‹ค๋ฅธ face๋Š” ํ…์Šค์ฒ˜๊ฐ€ ์žˆ์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๊ฒฝ์šฐ, ๋ณด๋‹ค ์œ ์—ฐํ•œ ํ‘œํ˜„์€ ํ…์Šค์ฒ˜ ์•„ํ‹€๋ผ์Šค(texture atlas)์ž…๋‹ˆ๋‹ค. ์—ฌ๊ธฐ์„œ ๊ฐ face๋Š” ์‚ฌ์šฉ์ž๊ฐ€ ๊ฒฐ์ •ํ•œ ํ…์Šค์ฒ˜ ํ•ด์ƒ๋„ R์— ๋”ฐ๋ฅธ R x R ํ…์Šค์ฒ˜ ๋งต์œผ๋กœ ํ‘œํ˜„๋ฉ๋‹ˆ๋‹ค. ์ด๋Š” ์†Œํ”„ํŠธ ๋ž˜์Šคํ„ฐ๋ผ์ด์ €(soft rasterizer) ๊ตฌํ˜„์—์„œ ์˜๊ฐ์„ ๋ฐ›์•˜์Šต๋‹ˆ๋‹ค. face์˜ ํŠน์ • ์ง€์ ์— ๋Œ€ํ•ด ํ…์Šค์ฒ˜ ๊ฐ’์€ ํ•ด๋‹น ์ง€์ ์˜ ์ค‘์  ์ขŒํ‘œ(barycentric coordinates)๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ face์˜ ํ…์Šค์ฒ˜ ๋งต์—์„œ ์ƒ˜ํ”Œ๋งํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด ํ‘œํ˜„์€ N x F x R x R x 3 ํ˜•ํƒœ์˜ ํ•˜๋‚˜์˜ ํ…์„œ๋ฅผ ํ•„์š”๋กœ ํ•ฉ๋‹ˆ๋‹ค.

image

# sugar_scene/sugar_model.py

class SuGaR(nn.Module):

...

    @property
    def texture_features(self):
        if not self._texture_initialized:
            self.update_texture_features()
        return self.sh_coordinates[self.point_idx_per_pixel]
    
    @property
    def mesh(self):        
        textures_uv = TexturesUV(
            maps=SH2RGB(self.texture_features[..., 0, :][None]), #texture_img[None]),
            verts_uvs=self.verts_uv[None],
            faces_uvs=self.faces_uv[None],
            sampling_mode='nearest',
            )
        
        return Meshes(
            verts=[self.triangle_vertices],   
            faces=[self.triangles],
            textures=textures_uv,
        )
        
    @property
    def surface_mesh(self):
        # Create a Meshes object
        surface_mesh = Meshes(
            verts=[self._points.to(self.device)],   
            faces=[self._surface_mesh_faces.to(self.device)],
            textures=TexturesVertex(verts_features=self._vertex_colors[None].clamp(0, 1).to(self.device)),
            # verts_normals=[verts_normals.to(rc.device)],
            )
        return surface_mesh

Leave a comment