Yin's Notebook

vuePress-theme-reco Howard Yin    2021 - 2026
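  • VGGT: How It Works

    • Aggregator structure
      • patch_embed
      • camera_token and register_token
      • Patch position information
      • Attention computation
    • CameraHead: iterative camera parameter refinement
    • DPTHead structure
      • Step 1: select 4 of the token layers
      • Step 2: project and reshape into 2D feature maps
      • Step 3: fuse level by level, coarse to fine (RefineNet style)
      • Step 4: restore the original pixel resolution
      • Step 5: final convolution and split of activations
      • In one sentence
    • TrackHead structure
      • Trajectory initialization
      • Building the correlation pyramid (CorrBlock)
      • Iterative refinement
      • Predicting visibility and confidence
      • Summary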

    VGGT: How It Works



    Howard Yin  2026-03-17 06:21:03  Category: Graphics  Tags: VGGT, 3D reconstruction

    The original VGGT consists of a backbone network (the Aggregator) and four heads:

    self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
    
    self.camera_head = CameraHead(dim_in=2 * embed_dim) if enable_camera else None
    self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1") if enable_point else None
    self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1") if enable_depth else None
    self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) if enable_track else None
    

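    At runtime the input is a batch of images and, optionally, a batch of query points for tracking. The outputs are the camera pose encodings plus depth maps and point maps with their confidences; if query points are given, the model also outputs each query point's pixel position, visibility and confidence in every image: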

    def forward(self, images: torch.Tensor, query_points: torch.Tensor = None):
        """
        Forward pass of the VGGT model.
    
        Args:
            images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
            query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
                Shape: [N, 2] or [B, N, 2], where N is the number of query points.
                Default: None
    
        Returns:
            dict: A dictionary containing the following predictions:
                - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
                - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
                - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
                - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
                - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
                - images (torch.Tensor): Original input images, preserved for visualization
    
                If query_points is provided, also includes:
                - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
                - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
                - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
        """        
        # If without batch dimension, add it
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
            
        if query_points is not None and len(query_points.shape) == 2:
            query_points = query_points.unsqueeze(0)
    
        aggregated_tokens_list, patch_start_idx = self.aggregator(images)
    
        predictions = {}
    
        with torch.cuda.amp.autocast(enabled=False):
            if self.camera_head is not None:
                pose_enc_list = self.camera_head(aggregated_tokens_list)
                predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last iteration
                predictions["pose_enc_list"] = pose_enc_list
                
            if self.depth_head is not None:
                depth, depth_conf = self.depth_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
                )
                predictions["depth"] = depth
                predictions["depth_conf"] = depth_conf
    
            if self.point_head is not None:
                pts3d, pts3d_conf = self.point_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
                )
                predictions["world_points"] = pts3d
                predictions["world_points_conf"] = pts3d_conf
    
        if self.track_head is not None and query_points is not None:
            track_list, vis, conf = self.track_head(
                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
            )
            predictions["track"] = track_list[-1]  # track of the last iteration
            predictions["vis"] = vis
            predictions["conf"] = conf
    
        if not self.training:
            predictions["images"] = images  # store the images for visualization during inference
    
        return predictions
    
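    The shape bookkeeping of forward() can be summarized in a few lines of pure Python. This sketch only manipulates shape tuples, not tensors; add_batch_dim and expected_output_shapes are hypothetical helpers, and the shapes are copied from the docstring above.

```python
def add_batch_dim(shape):
    """Mimic forward(): an unbatched [S, 3, H, W] input gets a leading batch dim of 1."""
    shape = tuple(shape)
    return (1,) + shape if len(shape) == 4 else shape


def expected_output_shapes(image_shape, num_query_points=None):
    """Prediction shapes documented in VGGT.forward (hypothetical helper)."""
    B, S, _, H, W = add_batch_dim(image_shape)
    shapes = {
        "pose_enc": (B, S, 9),             # camera encoding from the last iteration
        "depth": (B, S, H, W, 1),
        "depth_conf": (B, S, H, W),
        "world_points": (B, S, H, W, 3),
        "world_points_conf": (B, S, H, W),
    }
    if num_query_points is not None:       # tracking outputs only exist with query points
        N = num_query_points
        shapes["track"] = (B, S, N, 2)     # pixel coordinates from the last iteration
        shapes["vis"] = (B, S, N)
        shapes["conf"] = (B, S, N)
    return shapes


for name, shape in expected_output_shapes((4, 3, 518, 518), num_query_points=16).items():
    print(name, shape)
```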

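    VGGT is run on images of a fixed size: the original VGGT takes 518x518 inputs with a 14x14 patch_size (518 = 37 x 14, so each side splits into a whole number of patches). In the original code, an image of arbitrary size is first padded to a square and resized to 1024x1024 by load_and_preprocess_images_square, and then resized to 518x518 in run_VGGT before it enters VGGT.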
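    A quick check of the token arithmetic at 518x518. This is a sketch: the default of 4 register tokens and patch_start_idx = 1 + num_register_tokens (one camera token plus the register tokens) are assumptions about the released Aggregator configuration, and square_pad is a hypothetical helper that assumes centered padding.

```python
def token_layout(img_size=518, patch_size=14, num_register_tokens=4):
    """Per-frame token layout for a square input (sketch, assumed defaults)."""
    if img_size % patch_size != 0:
        raise ValueError("image side must be a multiple of the patch size")
    grid = img_size // patch_size              # patches per side: 518 / 14 = 37
    num_patches = grid * grid                  # 37 * 37 = 1369 patch tokens
    patch_start_idx = 1 + num_register_tokens  # 1 camera token + register tokens (assumed)
    tokens_per_frame = patch_start_idx + num_patches
    return grid, num_patches, patch_start_idx, tokens_per_frame


def square_pad(h, w):
    """Padding (top, bottom, left, right) that centers an h x w image in a square."""
    side = max(h, w)
    pad_h, pad_w = side - h, side - w
    return pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2


print(token_layout())          # (37, 1369, 5, 1374)
print(square_pad(1080, 1920))  # (420, 420, 0, 0)
```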

    # Aggregator structure

    # patch_embed

    The input images are normalized, reshaped to [B*S, C, H, W], and passed through patch_embed:

    B, S, C_in, H, W = images.shape
    
    if C_in != 3:
        raise ValueError(f"Expected 3 input channels, got {C_in}")
    
    # Normalize images and reshape for patch embed
    images = (images - self._resnet_mean) / self._resnet_std
    
    # Reshape to [B*S, C, H, W] for patch embedding
    images = images.view(B * S, C_in, H, W)
    patch_tokens = self.patch_embed(images)
    
    if isinstance(patch_tokens, dict):
        patch_tokens = patch_tokens["x_norm_patchtokens"]
    
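    The normalization subtracts a per-channel mean and divides by a per-channel std. The sketch below assumes _resnet_mean / _resnet_std hold the standard ImageNet statistics, as their names suggest, and a ViT-L backbone with embed_dim = 1024; both are assumptions, and the helper names are hypothetical.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)  # assumed values behind _resnet_mean
IMAGENET_STD = (0.229, 0.224, 0.225)   # assumed values behind _resnet_std


def normalize_pixel(rgb):
    """Normalize one RGB pixel in [0, 1] channel-wise: (x - mean) / std."""
    return tuple((x - m) / s for x, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


def patch_token_shape(B, S, H, W, patch_size=14, embed_dim=1024):
    """Shape of patch_tokens after images.view(B * S, C, H, W) and patch_embed."""
    num_patches = (H // patch_size) * (W // patch_size)
    return (B * S, num_patches, embed_dim)


print(patch_token_shape(1, 4, 518, 518))  # (4, 1369, 1024)
```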

    patch_embed is either a DINOv2 ViT or a plain convolutional patch embedding, selected by the patch_embed setting:

    if "conv" in patch_embed:
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
    else:
        vit_models = {
            "dinov2_vitl14_reg": vit_large,
            "dinov2_vitb14_reg": vit_base,
            "dinov2_vits14_reg": vit_small,
            "dinov2_vitg2_reg": vit_giant2,
        }
    
        self.patch_embed = vit_models[patch_embed](
            img_size=img_size,
            patch_size=patch_size,
            num_register_tokens=num_register_tokens,
            interpolate_antialias=interpolate_antialias,
            interpolate_offset=interpolate_offset,
            block_chunks=block_chunks,
            init_values=init_values,
        )
    
        # Disable gradient updates for mask token
        if hasattr(self.patch_embed, "mask_token"):
            self.patch_embed.mask_token.requires_grad_(False)
    

    # camera_token and register_token

    Next, a camera token and the register tokens are prepended to each frame's patch tokens:

    _, P, C = patch_tokens.shape
    
    # Expand camera and register tokens to match batch size and sequence length
    camera_token = slice_expand_and_flatten(self.camera_token, B, S)
    register_token = slice_expand_and_flatten(self.register_token, B, S)
    
    # Concatenate special tokens with patch tokens
    tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
    

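    slice_expand_and_flatten expands token_tensor to the batch size B and sequence length S: the first frame of each sequence gets the token at index 0, and all other frames share the token at index 1: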

    def slice_expand_and_flatten(token_tensor, B, S):
        """
        Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
        1) Uses the first position (index=0) for the first frame only
        2) Uses the second position (index=1) for all remaining frames (S-1 frames)
        3) Expands both to match batch size B
        4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
           followed by (S-1) second-position tokens
        5) Flattens to (B*S, X, C) for processing
    
        Returns:
            torch.Tensor: Processed tokens with shape (B*S, X, C)
        """
    
        # Slice out the "query" tokens => shape (1, 1, ...)
        query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
        # Slice out the "other" tokens => shape (1, S-1, ...)
        others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
        # Concatenate => shape (B, S, ...)
        combined = torch.cat([query, others], dim=1)
    
        # Finally flatten => shape (B*S, ...)
        combined = combined.view(B * S, *combined.shape[2:])
        return combined
    
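    A pure-Python analogue of slice_expand_and_flatten makes the first-frame / other-frames split concrete. It works on a two-element list instead of a (1, 2, X, C) tensor; strings stand in for the token tensors.

```python
def slice_expand_and_flatten_list(token_pair, B, S):
    """List analogue of slice_expand_and_flatten.

    token_pair[0] plays the role of token_tensor[:, 0] (first frame),
    token_pair[1] plays the role of token_tensor[:, 1] (every other frame).
    Returns a flat list of length B * S, ordered batch-major like .view(B * S, ...).
    """
    first, rest = token_pair
    flat = []
    for _ in range(B):
        flat.extend([first] + [rest] * (S - 1))  # 1 first-frame token, S-1 shared tokens
    return flat


print(slice_expand_and_flatten_list(["cam_first", "cam_rest"], B=2, S=3))
# ['cam_first', 'cam_rest', 'cam_rest', 'cam_first', 'cam_rest', 'cam_rest']
```

    Frame 0 of every sequence always receives the index-0 token, which is how the network can tell the reference frame apart from the others.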

    camera_token and register_token are two trainable nn.Parameters, each holding one variant for the first frame and one for the remaining frames:

    # Note: We have two camera tokens, one for the first frame and one for the rest
    # The same applies for register tokens
    self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
    self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
    

    Both are initialized with nn.init.normal_ and a tiny standard deviation (1e-6):

    # Initialize parameters with small values
    nn.init.normal_(self.camera_token, std=1e-6)
    nn.init.normal_(self.register_token, std=1e-6)
    

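    In short, camera_token and register_token are learnable tokens prepended to every image's patch tokens before they enter the transformer. All frames except the first share the same copy; the first frame gets a dedicated token of its own, which marks it as the reference frame.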

    # Patch position information

    Next, the positions of the patches are computed:

    pos = None
    if self.rope is not None:
        pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
    
    if self.patch_start_idx > 0:
        # do not use position embedding for special tokens (camera and register tokens)
        # so set pos to 0 for the special tokens
        pos = pos + 1
        pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
        pos = torch.cat([pos_special, pos], dim=1)
    

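    These are simply each patch's (y, x) coordinate on the patch grid. The code above adds 1 to every patch position so that (0, 0) is left for the special tokens (camera and register tokens), which therefore receive no positional encoding. PositionGetter generates the grid coordinates: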

    class PositionGetter:
        """Generates and caches 2D spatial positions for patches in a grid.
    
        This class efficiently manages the generation of spatial coordinates for patches
        in a 2D grid, caching results to avoid redundant computations.
    
        Attributes:
            position_cache: Dictionary storing precomputed position tensors for different
                grid dimensions.
        """
    
        def __init__(self):
            """Initializes the position generator with an empty cache."""
            self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
    
        def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
            """Generates spatial positions for a batch of patches.
    
            Args:
                batch_size: Number of samples in the batch.
                height: Height of the grid in patches.
                width: Width of the grid in patches.
                device: Target device for the position tensor.
    
            Returns:
                Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
                for each position in the grid, repeated for each batch item.
            """
            if (height, width) not in self.position_cache:
                y_coords = torch.arange(height, device=device)
                x_coords = torch.arange(width, device=device)
                positions = torch.cartesian_prod(y_coords, x_coords)
                self.position_cache[height, width] = positions
    
            cached_positions = self.position_cache[height, width]
            return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
    
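    Putting PositionGetter and the shift above together for one frame: patches get their row-major (y, x) grid coordinate plus 1, and the patch_start_idx special tokens are all placed at (0, 0). This stdlib sketch uses itertools.product, which yields pairs in the same order as torch.cartesian_prod; token_positions is a hypothetical name.

```python
from itertools import product


def token_positions(grid_h, grid_w, patch_start_idx):
    """(y, x) position of every token in one frame: special tokens first, then patches."""
    patch_pos = [(y + 1, x + 1) for y, x in product(range(grid_h), range(grid_w))]
    special_pos = [(0, 0)] * patch_start_idx  # special tokens all sit at position (0, 0)
    return special_pos + patch_pos


print(token_positions(2, 3, patch_start_idx=2))
# [(0, 0), (0, 0), (1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)]
```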

    # Attention computation

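    The tokens then pass through aa_block_num rounds of attention. In each round, the attention types in aa_order are applied in turn, each being either frame (intra-frame) attention or global attention; the round's frame and global intermediate outputs are then concatenated along the channel dimension and appended to output_list, which is returned with all intermediate tokens: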

    # update P because we added special tokens
    _, P, C = tokens.shape
    
    frame_idx = 0
    global_idx = 0
    output_list = []
    
    for _ in range(self.aa_block_num):
        for attn_type in self.aa_order:
            if attn_type == "frame":
                tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                    tokens, B, S, P, C, frame_idx, pos=pos
                )
            elif attn_type == "global":
                tokens, global_idx, global_intermediates = self._process_global_attention(
                    tokens, B, S, P, C, global_idx, pos=pos
                )
            else:
                raise ValueError(f"Unknown attention type: {attn_type}")
    
        for i in range(len(frame_intermediates)):
            # concat frame and global intermediates, [B x S x P x 2C]
            concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
            output_list.append(concat_inter)
    
    del concat_inter
    del frame_intermediates
    del global_intermediates
    return output_list, self.patch_start_idx
    
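    The loop's bookkeeping can be simulated without tensors. This sketch assumes depth = 24 blocks of each kind and aa_block_num = depth // aa_block_size (an assumption about how aa_block_num is derived). Each round appends frame/global pairs concatenated along channels, which is why the heads are built with dim_in = 2 * embed_dim and why DPTHead can index layers 0 to 23.

```python
def simulate_aggregator(depth=24, aa_block_size=1, embed_dim=1024, aa_order=("frame", "global")):
    """Count the entries of output_list and their channel width (shape bookkeeping only)."""
    aa_block_num = depth // aa_block_size
    frame_idx = global_idx = 0
    output_list = []
    for _ in range(aa_block_num):
        frame_inter, global_inter = [], []
        for attn_type in aa_order:
            for _ in range(aa_block_size):
                if attn_type == "frame":
                    frame_inter.append((frame_idx, embed_dim))
                    frame_idx += 1
                elif attn_type == "global":
                    global_inter.append((global_idx, embed_dim))
                    global_idx += 1
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")
        for (f_idx, f_dim), (g_idx, g_dim) in zip(frame_inter, global_inter):
            # torch.cat([frame, global], dim=-1) -> channel width C + C = 2C
            output_list.append((f_idx, g_idx, f_dim + g_dim))
    return output_list, frame_idx, global_idx


outputs, n_frame, n_global = simulate_aggregator()
print(len(outputs), outputs[-1], n_frame, n_global)  # 24 (23, 23, 2048) 24 24
```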

    By default, aa_order is one frame attention followed by one global attention:

    aa_order=["frame", "global"]
    

    Frame attention and global attention use blocks with exactly the same structure:

    self.frame_blocks = nn.ModuleList(
        [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                init_values=init_values,
                qk_norm=qk_norm,
                rope=self.rope,
            )
            for _ in range(depth)
        ]
    )
    
    self.global_blocks = nn.ModuleList(
        [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                init_values=init_values,
                qk_norm=qk_norm,
                rope=self.rope,
            )
            for _ in range(depth)
        ]
    )
    

    The difference is how the tokens are reshaped before the blocks are run:

    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        # If needed, reshape tokens or positions:
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)
    
        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)
    
        intermediates = []
    
        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))
    
        return tokens, frame_idx, intermediates
    
    def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)
    
        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)
    
        intermediates = []
    
        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = self.global_blocks[global_idx](tokens, pos=pos)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))
    
        return tokens, global_idx, intermediates
    

    Note that the only difference between the two functions is in their first few lines:

    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        # ...
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)
    
        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)
    
        # ...
    
    def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
        # ...
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)
    
        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)
    
        # ...
    

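    In frame attention the batch dimension is B*S and the sequence length is P (the tokens of a single frame), so each batch element computes attention within one frame. In global attention the batch dimension is B and the sequence length is S*P (the tokens of all frames concatenated), so each batch element computes attention among all tokens of all frames.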
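    To see what the two reshapes mean, the sketch below lists which (batch, frame, token) index triples share one attention call, grouping them exactly as the view calls do, and counts query-key pairs. It is illustrative pure Python with hypothetical helper names; the demo assumes 1374 tokens per frame. For the same tokens, global attention computes S times as many query-key pairs as frame attention.

```python
def attention_groups(B, S, P, mode):
    """Token groups that attend to each other; each group is a list of (b, s, p)."""
    if mode == "frame":   # tokens viewed as (B*S, P, C): one group per frame
        return [[(b, s, p) for p in range(P)] for b in range(B) for s in range(S)]
    if mode == "global":  # tokens viewed as (B, S*P, C): one group per sequence
        return [[(b, s, p) for s in range(S) for p in range(P)] for b in range(B)]
    raise ValueError(f"Unknown attention type: {mode}")


def attention_pairs(groups):
    """Number of query-key pairs, i.e. the total size of all attention matrices."""
    return sum(len(g) ** 2 for g in groups)


frame = attention_groups(1, 4, 1374, "frame")
glob = attention_groups(1, 4, 1374, "global")
print(len(frame), len(glob))                            # 4 1
print(attention_pairs(glob) // attention_pairs(frame))  # 4
```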

    # CameraHead: iterative camera parameter refinement

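    The input is the last-layer tokens from the aggregator ([B, S, P, 2C]). CameraHead keeps only token 0 of each frame (the camera token), giving pose_tokens of shape [B, S, 2C]. The target output is a 9-dimensional camera encoding per frame: [T(3), quat(4), FoV(2)]: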

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
        """
        Forward pass to predict camera parameters.
    
        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
    
        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
        """
        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]
    
        # Extract the camera tokens
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)
    
        pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
        return pred_pose_enc_list
    

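    CameraHead.trunk_fn is the core of CameraHead: it regresses the camera parameters iteratively, one increment at a time.

    Below is a step-by-step walk-through of the iterative refinement in trunk_fn.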


    # Overview

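    trunk_fn combines the AdaLN (Adaptive Layer Norm) modulation of DiT (Diffusion Transformer) with the iterative incremental updates of RAFT. Instead of regressing the camera parameters in one shot, each round uses the current estimate to modulate the tokens and then predicts an increment, approaching the true value step by step.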


    # Step 0: initialization

    B, S, C = pose_tokens.shape   # B = batch, S = number of frames, C = 2048
    pred_pose_enc = None           # no camera estimate yet
    pred_pose_enc_list = []        # collects the output of every round
    

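    pose_tokens are the camera tokens output by the aggregator (one per frame), passed through a LayerNorm. They carry global information about what each frame's camera in the scene should be, but have not yet been decoded into concrete 9D parameters.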


    It then runs several rounds:

    for _ in range(num_iterations):
    

    # Step 1: build the conditioning input (the current camera estimate)

    if pred_pose_enc is None:
        module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
    else:
        pred_pose_enc = pred_pose_enc.detach()
        module_input = self.embed_pose(pred_pose_enc)
    
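    • Round 1: there is no prior yet, so the learnable empty_pose_tokens (zero-initialized, shape [1, 1, 9], broadcast to [B, S, 9], so every frame gets the same input) are mapped into token space by embed_pose (Linear 9 -> 2048).
    • Later rounds: the previous round's 9D camera prediction is mapped into token space; detach() cuts the gradient between iterations so that each round is trained as a separate correction step, as in RAFT.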

    module_input now has shape [B, S, 2048] and expresses the current best estimate of the camera parameters in token space.


    # Step 2: generate the AdaLN modulation parameters

    shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
    

    poseLN_modulation is SiLU -> Linear(2048, 3*2048); it turns the conditioning input into three groups of parameters:

    • shift_msa [B,S,2048]: shifts the features
    • scale_msa [B,S,2048]: scales the features
    • gate_msa [B,S,2048]: controls how strongly the modulated features are applied

    # Step 3: modulate the pose tokens with AdaLN

    pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
    pose_tokens_modulated = pose_tokens_modulated + pose_tokens
    

    Breaking this down:

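    1. self.adaln_norm(pose_tokens): apply LayerNorm first
    2. modulate(x, shift, scale), i.e. x * (1 + scale) + shift: adaptively scale and shift the features according to the current camera estimate
    3. multiply by gate_msa: controls the strength of the modulation signal
    4. add the original pose_tokens back as a residual connection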

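    In the first round the condition comes only from the learnable empty_pose_tokens, so it carries no frame-specific information and every frame is modulated identically. In later rounds the condition encodes each frame's own current estimate, so the modulation becomes increasingly targeted.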
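    Steps 2 and 3 reduce to a few element-wise operations. This scalar-list sketch uses modulate(x, shift, scale) = x * (1 + scale) + shift as stated above, and replaces adaln_norm with a plain LayerNorm without learnable affine parameters (an assumption); layer_norm and adaln_residual are hypothetical names.

```python
import math


def layer_norm(x, eps=1e-6):
    """LayerNorm over one token vector, without learnable affine parameters."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def modulate(x, shift, scale):
    """AdaLN modulation: x * (1 + scale) + shift, element-wise."""
    return [v * (1 + sc) + sh for v, sh, sc in zip(x, shift, scale)]


def adaln_residual(pose_token, shift, scale, gate):
    """gate * modulate(norm(token), shift, scale) + token, for a single token."""
    modulated = modulate(layer_norm(pose_token), shift, scale)
    return [g * m + t for g, m, t in zip(gate, modulated, pose_token)]


token = [1.0, 2.0, 3.0, 4.0]
zeros = [0.0] * 4
print(adaln_residual(token, zeros, zeros, gate=zeros))  # [1.0, 2.0, 3.0, 4.0]
```

    With gate_msa = 0 the update reduces to the identity on pose_tokens, so the residual connection guarantees the trunk always sees the original camera token.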


    # Step 4: refine with the Transformer trunk

    pose_tokens_modulated = self.trunk(pose_tokens_modulated)
    

    self.trunk is a stack of 4 Blocks (standard Transformer blocks with self-attention + FFN). Its input has shape [B, S, 2048], so the S frames attend to one another.

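    This step lets the camera estimates of different frames inform one another; cross-frame geometric constraints (for example, how frame 2 must be oriented given frame 1's orientation) are modeled implicitly here.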


    # Step 5: predict an increment and accumulate it

    pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
    
    if pred_pose_enc is None:
        pred_pose_enc = pred_pose_enc_delta
    else:
        pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
    

    pose_branch is an MLP (2048 -> 1024 -> 9) that maps the trunk output back into the 9D camera encoding space.

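    Key point: pose_branch outputs an increment, pred_pose_enc_delta, not an absolute value. In the first round this increment is taken directly as the estimate; in every later round the network only predicts a correction to the running estimate.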
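    The control flow of Steps 1 to 5 (first-round output taken as the estimate, later outputs added as corrections, the previous estimate fed back as the condition) can be shown with a toy loop. fake_delta is a hypothetical stand-in for the network that predicts half of the remaining error to a known target; in VGGT the increment comes from the modulated trunk and pose_branch, and there is no known target at inference.

```python
def refine(target, num_iterations=4):
    """Toy version of trunk_fn's control flow on a 9D list (fake network)."""
    empty_pose = [0.0] * len(target)  # stands in for the zero-initialized empty_pose_tokens

    def fake_delta(condition):
        # Hypothetical stand-in for pose_branch(trunk(...)): predict half the remaining error.
        return [(t - c) * 0.5 for t, c in zip(target, condition)]

    pred = None
    history = []
    for _ in range(num_iterations):
        # Condition on the previous estimate (VGGT detaches it at this point).
        condition = empty_pose if pred is None else pred
        delta = fake_delta(condition)
        pred = delta if pred is None else [p + d for p, d in zip(pred, delta)]
        history.append(list(pred))  # every round is recorded, like pred_pose_enc_list
    return history


def max_error(pred, target):
    return max(abs(p - t) for p, t in zip(pred, target))


target = [0.1, -0.2, 0.3, 0.0, 0.0, 0.0, 1.0, 0.8, 0.8]
history = refine(target)
print([round(max_error(p, target), 4) for p in history])  # [0.5, 0.25, 0.125, 0.0625]
```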


    # Step 6: activate and collect

    activated_pose = activate_pose(
        pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
    )
    pred_pose_enc_list.append(activated_pose)
    

    Each of the three parts of the 9D vector gets its own activation:

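    • T[:3]: linear (translation is unconstrained)
    • quat[3:7]: linear (the quaternion is unconstrained here and normalized later, when converted to a rotation matrix)
    • FoV[7:9]: relu (the field of view must not be negative)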

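    Every round's activated_pose is recorded: during training all rounds can be supervised (deep supervision), and at inference the last round is used (predictions["pose_enc"] = pose_enc_list[-1] in forward).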
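    A minimal stdlib sketch of the per-part activation with the defaults listed above (linear, linear, relu). activate_pose_9d and normalize_quat are hypothetical names; normalize_quat only mimics the normalization that happens later, when the quaternion is converted to a rotation matrix.

```python
import math


def activate_pose_9d(enc):
    """Split a 9D encoding into T[:3], quat[3:7], FoV[7:9] and activate each part."""
    if len(enc) != 9:
        raise ValueError("expected a 9D pose encoding")
    T = list(enc[:3])                      # linear: translation unconstrained
    quat = list(enc[3:7])                  # linear: normalized only at conversion time
    fov = [max(0.0, v) for v in enc[7:9]]  # relu: field of view cannot be negative
    return T + quat + fov


def normalize_quat(q):
    """Unit-normalize a quaternion, as done before building a rotation matrix."""
    n = math.sqrt(sum(v * v for v in q))
    return [v / n for v in q]


act = activate_pose_9d([0.5, -1.0, 2.0, 0.0, 0.0, 0.0, 2.0, -0.3, 1.2])
print(act[7:], normalize_quat(act[3:7]))  # [0.0, 1.2] [0.0, 0.0, 0.0, 1.0]
```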

    # DPTHead structure

    point_head and depth_head are both DPTHeads; they differ only in the number of output channels.

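    DPT (Dense Prediction Transformer) comes from the paper "Vision Transformers for Dense Prediction" (Ranftl et al., 2021). Core idea: extract multi-scale features from different depths of a Transformer and fuse them level by level, much like an FPN, to recover pixel-level resolution.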

    VGGT's DPTHead is built on this framework, but its input is not ordinary ViT features; it is the cross-frame aggregated tokens produced by the Aggregator.


    # Step 1: select 4 of the token layers

        intermediate_layer_idx: List[int] = [4, 11, 17, 23],
    

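    Layers 4, 11, 17 and 23 of the Aggregator's 24 outputs are selected, representing shallow, middle, deep and deepest features respectively. Each layer's tokens have shape [B, S, P, 2C]; the special tokens are dropped and only the patch tokens are kept.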

    # Step 2: project and reshape into 2D feature maps

            # dpt_idx (0..3) selects the matching projects / resize_layers entry
            for dpt_idx, layer_idx in enumerate(self.intermediate_layer_idx):
                x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
                # ...
                x = x.reshape(B * S, -1, x.shape[-1])
                x = self.norm(x)
                x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
                x = self.projects[dpt_idx](x)
                if self.pos_embed:
                    x = self._apply_pos_embed(x, W, H)
                x = self.resize_layers[dpt_idx](x)
                out.append(x)
    

    For each layer:

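    1. LayerNorm normalization
    2. reshape to [B*S, 2C, patch_h, patch_w] (restoring the spatial structure)
    3. a 1x1 Conv projects each layer to its own channel count: [256, 512, 1024, 1024]
    4. optionally add a positional encoding (UV sine/cosine embedding)
    5. resize to a different spatial scale:
      • layer 0: ConvTranspose2d, stride 4 (4x upsampling)
      • layer 1: ConvTranspose2d, stride 2 (2x upsampling)
      • layer 2: Identity (unchanged)
      • layer 3: Conv2d, stride 2 (2x downsampling)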

    This yields 4 feature maps at different scales.
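    The spatial sizes of the four maps follow from the standard convolution size formulas. The kernel sizes and padding used here (ConvTranspose2d k=4/s=4 and k=2/s=2, Conv2d k=3/s=2/p=1) are assumed from the original DPT design rather than taken from the code above; a 518x518 input gives a 37x37 patch grid.

```python
def conv_out(size, kernel, stride, padding=0):
    """Output size of Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride, padding=0):
    """Output size of ConvTranspose2d (dilation 1, output_padding 0)."""
    return (size - 1) * stride - 2 * padding + kernel


def dpt_feature_sizes(patch_grid=37):
    """Spatial size of the 4 resized DPT feature maps (assumed kernel settings)."""
    return [
        conv_transpose_out(patch_grid, kernel=4, stride=4),   # layer 0: x4 -> 148
        conv_transpose_out(patch_grid, kernel=2, stride=2),   # layer 1: x2 -> 74
        patch_grid,                                           # layer 2: Identity -> 37
        conv_out(patch_grid, kernel=3, stride=2, padding=1),  # layer 3: /2 -> 19
    ]


OUT_CHANNELS = [256, 512, 1024, 1024]  # per-layer projection widths from the text
print(list(zip(OUT_CHANNELS, dpt_feature_sizes())))
# [(256, 148), (512, 74), (1024, 37), (1024, 19)]
```

    Note that the stride-2 Conv2d rounds the odd 37 up to 19 rather than down to 18.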

    # Step 3: fuse level by level, coarse to fine (RefineNet style)

        def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
            layer_1, layer_2, layer_3, layer_4 = features
    
            layer_1_rn = self.scratch.layer1_rn(layer_1)
            layer_2_rn = self.scratch.layer2_rn(layer_2)
            layer_3_rn = self.scratch.layer3_rn(layer_3)
            layer_4_rn = self.scratch.layer4_rn(layer_4)
    
            out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
            # ...
            out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
            # ...
            out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
            # ...
            out = self.scratch.refinenet1(out, layer_1_rn)
            # ...
            out = self.scratch.output_conv1(out)
            return out
    
    • First, each of the 4 layers goes through a 3x3 Conv (layer1_rn ... layer4_rn) that brings it to 256 channels
    • Then the levels are fused from the deepest to the shallowest:
      • refinenet4 → upsample → add layer_3 → refinenet3 → ... → refinenet1
    • Each FeatureFusionBlock works as follows:
    class FeatureFusionBlock(nn.Module):
        # ...
        def forward(self, *xs, size=None):
            output = xs[0]
            if self.has_residual:
                res = self.resConfUnit1(xs[1])
                output = self.skip_add.add(output, res)
            output = self.resConfUnit2(output)
            # ... bilinear upsample ...
            output = self.out_conv(output)
            return output
    

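    That is: output of the previous level + the current level passed through a residual conv unit → another residual conv unit → bilinear upsampling → 1x1 Conv.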

    This is similar to U-Net / FPN, but built entirely from residual convolution units (ResidualConvUnit: two 3x3 Convs plus a skip connection).

    # Step 4: restore the original pixel resolution

            out = custom_interpolate(
                out,
                (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
                mode="bilinear",
                align_corners=True,
            )
    

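    The fused feature map is bilinearly interpolated once more, to (patch_h * patch_size / down_ratio) x (patch_w * patch_size / down_ratio): H x W for down_ratio=1, or H/2 x W/2 for down_ratio=2.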
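    Tracing the spatial size through the fusion chain for a 518x518 input (37x37 patch grid). This sketch assumes the four resized maps are 148, 74, 37 and 19 pixels wide (standard DPT resize layers), that refinenet4 to refinenet2 resize to the next shallower map as their size= arguments indicate, and that refinenet1, called without size, upsamples by a factor of 2 (an assumption about FeatureFusionBlock's default).

```python
def fusion_schedule(layer_sizes=(148, 74, 37, 19)):
    """Spatial size after each refinenet stage, deepest first (assumed behavior)."""
    l1, l2, l3, l4 = layer_sizes
    stages = [("refinenet4", l3), ("refinenet3", l2), ("refinenet2", l1), ("refinenet1", None)]
    out = l4  # refinenet4 starts from the deepest, smallest map
    schedule = []
    for name, target in stages:
        # Stages 4..2 resize to the next shallower map; refinenet1 is assumed to upsample x2.
        out = target if target is not None else out * 2
        schedule.append((name, out))
    return schedule


def dpt_output_size(patch_grid, patch_size=14, down_ratio=1):
    """Target size passed to custom_interpolate: int(patch_h * patch_size / down_ratio)."""
    return int(patch_grid * patch_size / down_ratio)


print(fusion_schedule())
# [('refinenet4', 37), ('refinenet3', 74), ('refinenet2', 148), ('refinenet1', 296)]
print(dpt_output_size(37), dpt_output_size(37, down_ratio=2))  # 518 259
```

    Under these assumptions the fused map is 296x296 and the final interpolation brings it to 518x518, or to 259x259 for the tracking features (down_ratio=2).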

    # Step 5: final convolution and split of activations

            out = self.scratch.output_conv2(out)
            preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
    
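    • output_conv2: 3x3 Conv → ReLU → 1x1 Conv, producing output_dim channels
    • activate_head splits off the last channel as the confidence and uses the remaining channels as the prediction:

    | Head | output_dim | Prediction | Activation | Confidence activation |
    |---|---|---|---|---|
    | DepthHead | 2 | 1 channel (depth) | exp (keeps it positive) | 1 + exp(x) |
    | PointHead | 4 | 3 channels (xyz) | inv_log (keeps the sign, covers a large range) | 1 + exp(x) |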
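    A scalar sketch of the channel split and the activations in the table. It assumes the confidence is the last channel and that inv_log is sign(x) * (exp(|x|) - 1), the inverse of a signed log transform; inv_log, expp1 and activate_pixel here are stand-ins, not the library's API.

```python
import math


def inv_log(x):
    """Signed inverse-log: sign(x) * (exp(|x|) - 1); keeps the sign, grows exponentially."""
    return math.copysign(math.expm1(abs(x)), x)


def expp1(x):
    """Confidence activation 1 + exp(x): always greater than 1."""
    return 1.0 + math.exp(x)


def activate_pixel(channels, activation):
    """Split one pixel's output_dim channels into (prediction, confidence)."""
    *pred, conf = channels
    act = {"exp": math.exp, "inv_log": inv_log}[activation]
    return [act(v) for v in pred], expp1(conf)


depth, depth_conf = activate_pixel([0.0, 0.0], "exp")             # output_dim = 2
xyz, xyz_conf = activate_pixel([-1.0, 0.0, 2.0, 0.0], "inv_log")  # output_dim = 4
print(depth, depth_conf)  # [1.0] 2.0
```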

    # In one sentence

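    DPTHead = restore multi-layer tokens as multi-scale 2D feature maps → fuse them from coarse to fine with residual units (like an FPN) → upsample to pixel resolution → split into predictions and confidence. In essence, it turns the Transformer's flat token sequence back into a spatial pyramid and then performs dense prediction with U-Net-style level-by-level fusion.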

    # TrackHead structure

    At its core, TrackHead refines point trajectories iteratively, based on feature correlation and a spatio-temporal Transformer. Its design is heavily inspired by CoTracker and VGGSfM.

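    In short, the workflow is: extract dense feature maps -> sample features at the query points in the reference frame -> make a rough initial guess of the trajectories -> search locally for matches (compute correlations) -> predict corrections with a Transformer -> repeat the refinement.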

    It consists of two parts, a DPTHead and a Tracker module:

    # Feature extractor based on DPT architecture
    # Processes tokens into feature maps for tracking
    self.feature_extractor = DPTHead(
        dim_in=dim_in,
        patch_size=patch_size,
        features=features,
        feature_only=True,  # Only output features, no activation
        down_ratio=2,  # Reduces spatial dimensions by factor of 2
        pos_embed=False,
    )
    
    # Tracker module that predicts point trajectories
    # Takes feature maps and predicts coordinates and visibility
    self.tracker = BaseTrackerPredictor(
        latent_dim=features,  # Match the output_dim of feature extractor
        predict_conf=predict_conf,
        stride=stride,
        corr_levels=corr_levels,
        corr_radius=corr_radius,
        hidden_size=hidden_size,
    )
    

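    down_ratio=2 means the output feature maps have half the input resolution (e.g. a 518x518 input gives 259x259 feature maps), which saves a lot of GPU memory while keeping tracking accurate.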

    At inference, the DPTHead first extracts dense features, and the Tracker module then produces the final result:

    def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
        """
        Forward pass of the TrackHead.
    
        Args:
            aggregated_tokens_list (list): List of aggregated tokens from the backbone.
            images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
                                    B = batch size, S = sequence length.
            patch_start_idx (int): Starting index for patch tokens.
            query_points (torch.Tensor, optional): Initial query points to track.
                                                    If None, points are initialized by the tracker.
            iters (int, optional): Number of refinement iterations. If None, uses self.iters.
    
        Returns:
            tuple:
                - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
                - vis_scores (torch.Tensor): Visibility scores for tracked points.
                - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
        """
        B, S, _, H, W = images.shape
    
        # Extract features from tokens
        # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
        feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
    
        # Use default iterations if not specified
        if iters is None:
            iters = self.iters
    
        # Perform tracking using the extracted features
        coord_preds, vis_scores, conf_scores = self.tracker(query_points=query_points, fmaps=feature_maps, iters=iters)
    
        return coord_preds, vis_scores, conf_scores
    

    The Tracker module is the core of TrackHead. Its principle comes from the CoTracker / RAFT family: correlation-based iterative coordinate refinement.

    Following the code, the principle breaks down into the following core steps:

    # Trajectory initialization

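    On entering BaseTrackerPredictor, the model needs an initial guess. It simply copies the query_points of the first (reference) frame to all frames as the initial coordinates, and samples each query point's feature from the first frame's feature map as the initial tracking feature (track_feats).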

        def forward(self, query_points, fmaps=None, iters=6, ...):
            B, N, D = query_points.shape
            B, S, C, HH, WW = fmaps.shape
            # ...
            query_points = query_points / float(self.stride)
    
            # Init with coords as the query points
            coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
    
            # Sample/extract the features of the query points in the query frame
            query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
    
            # init track feats by query feats
            track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1)  # B, S, N, C
    
            fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
    
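    • query_points ([B,N,2]): the pixel coordinates to track; the code first divides them by self.stride to convert them into feature-map coordinates
    • coords: initialized as "in every frame the point sits at the query location", i.e. the point is assumed not to move
    • query_track_feat: the features of the query points sampled in frame 0 (the query frame), describing "what to look for"
    • track_feats: each point's feature representation in each frame, initialized as a copy of the query features
    • CorrBlock: builds a multi-scale pyramid over the feature maps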
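    The initialization step can be reproduced in isolation. This is a minimal sketch, not VGGT's code: sample_feats_at_points and init_tracks are hypothetical names, and implementing the sampling with F.grid_sample (bilinear, align_corners=True) is an assumption about what sample_features4d does.

```python
import torch
import torch.nn.functional as F


def sample_feats_at_points(fmap: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
    """Bilinearly sample features at sub-pixel points (hypothetical helper).

    fmap:   (B, C, H, W) feature map
    coords: (B, N, 2) pixel coordinates (x, y) in fmap's resolution
    returns (B, N, C)
    """
    _, _, H, W = fmap.shape
    # Map pixel coords to grid_sample's [-1, 1] range (align_corners=True: 0 -> -1, W-1 -> 1)
    gx = 2.0 * coords[..., 0] / (W - 1) - 1.0
    gy = 2.0 * coords[..., 1] / (H - 1) - 1.0
    grid = torch.stack([gx, gy], dim=-1).unsqueeze(2)  # (B, N, 1, 2)
    out = F.grid_sample(fmap, grid, mode="bilinear", align_corners=True)  # (B, C, N, 1)
    return out.squeeze(-1).permute(0, 2, 1)  # (B, N, C)


def init_tracks(query_points: torch.Tensor, fmaps: torch.Tensor, stride: int):
    """query_points: (B, N, 2) in image pixels; fmaps: (B, S, C, HH, WW)."""
    B, N, _ = query_points.shape
    S = fmaps.shape[1]
    qp = query_points / float(stride)  # image pixels -> feature-map pixels
    # Initial guess: the point sits at the query location in every frame
    coords = qp.reshape(B, 1, N, 2).repeat(1, S, 1, 1)  # (B, S, N, 2)
    # "What to look for": features at the query points in frame 0
    query_feat = sample_feats_at_points(fmaps[:, 0], coords[:, 0])  # (B, N, C)
    track_feats = query_feat.unsqueeze(1).repeat(1, S, 1, 1)  # (B, S, N, C)
    return coords, track_feats
```

    With stride=2, a query at image pixel (4, 6) lands on feature pixel (2, 3), and the feature read there is copied to every frame as the starting track feature.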

    # Building the Correlation Pyramid (CorrBlock)

    CorrBlock.__init__ simply applies 2x average pooling to the [B,S,C,HH,WW] feature maps over and over, appending each result to self.fmaps_pyramid:

    class CorrBlock:
        def __init__(self, fmaps, num_levels=4, radius=4, ...):
            # ...
            self.fmaps_pyramid = [fmaps]  # level 0 is full resolution
            current_fmaps = fmaps
            for i in range(num_levels - 1):
                # ...
                current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
                self.fmaps_pyramid.append(current_fmaps)
    

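    This builds a feature pyramid with num_levels levels: each higher level has half the resolution of the one below, but each of its pixels summarizes a larger neighborhood. On top of this pyramid, the later steps let the tracker see both coarse, wide-range matching information and fine, local matching information at the same time.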
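    The loop above elides how the 5D tensor is pooled. Since F.avg_pool2d only accepts 3D or 4D input, a plausible reading (an assumption) is that the batch and frame dimensions are folded together before pooling. A minimal sketch with a hypothetical build_fmap_pyramid:

```python
import torch
import torch.nn.functional as F


def build_fmap_pyramid(fmaps: torch.Tensor, num_levels: int = 4) -> list:
    """fmaps: (B, S, C, H, W). Returns num_levels tensors; each level halves H and W."""
    pyramid = [fmaps]  # level 0: full resolution
    cur = fmaps
    for _ in range(num_levels - 1):
        B, S, C, H, W = cur.shape
        # avg_pool2d expects (N, C, H, W): fold batch and frames together
        pooled = F.avg_pool2d(cur.reshape(B * S, C, H, W), kernel_size=2, stride=2)
        cur = pooled.reshape(B, S, C, pooled.shape[-2], pooled.shape[-1])
        pyramid.append(cur)
    return pyramid
```

    Each coarser pixel is the mean of a 2×2 block of the level below, which is why one pixel at level i summarizes a 2**i × 2**i region of the full-resolution feature map.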

    # Iterative Refinement

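    This is TrackHead's central mechanism. The model loops iters times (4 or 6 by default); in each iteration it "looks around" the current coordinates and predicts an offset $(\Delta x, \Delta y)$ that corrects them. Each iteration consists of four key operations: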

            for _ in range(iters):
                coords = coords.detach()
    
                # ① 相关性采样
                fcorrs = fcorr_fn.corr_sample(track_feats, coords)
    
                # ② 位移编码
                flows = (coords - coords[:, 0:1]).permute(...)
                flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
                flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
    
                # ③ Update Transformer
                transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
                # ...
                delta, _ = self.updateformer(x)
    
                # ④ 更新坐标和特征
                coords = coords + delta_coords_
                track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
    

    Step by step:

    # Operation ①: Local Correlation Sampling (CorrBlock.corr_sample)

    This method belongs to CorrBlock and works on the feature pyramid built in CorrBlock.__init__.

    def corr_sample(self, targets, coords):
        """
        Instead of storing the entire correlation pyramid, we compute each level's correlation
        volume, sample it immediately, then discard it. This saves GPU memory.
    
        Args:
            targets: Tensor (B, S, N, C) — features for the current targets.
            coords: Tensor (B, S, N, 2) — coordinates at full resolution.
    
        Returns:
            Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
        """
        B, S, N, C = targets.shape
    
        # If you have multiple track features, split them per level.
        if self.multiple_track_feats:
            targets_split = torch.split(targets, C // self.num_levels, dim=-1)
    
    

    The following is done on each pyramid level:

        out_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            # Get current spatial resolution H, W for this pyramid level.
            B, S, C, H, W = fmaps.shape
    

    First, the dot product (correlation) between each point's track feature and the feature of every pixel on that level is computed:

            # Reshape feature maps for correlation computation:
            # fmap2s: (B, S, C, H*W)
            fmap2s = fmaps.view(B, S, C, H * W)
            # Choose appropriate target features.
            fmap1 = targets_split[i] if self.multiple_track_feats else targets  # shape: (B, S, N, C)
    
            # Compute correlation directly
            corrs = compute_corr_level(fmap1, fmap2s, C)
            corrs = corrs.view(B, S, N, H, W)
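    compute_corr_level itself is not shown here. The sketch below assumes it is a batched matrix product followed, since it receives C as an argument, by a 1/sqrt(C) scaling as used for attention logits; the scaling is an assumption, and corr_level_sketch is a hypothetical name:

```python
import math

import torch


def corr_level_sketch(track_feats: torch.Tensor, fmap_level: torch.Tensor) -> torch.Tensor:
    """Dot-product correlation of every track feature with every pixel of one level.

    track_feats: (B, S, N, C); fmap_level: (B, S, C, H, W) -> (B, S, N, H, W)
    """
    B, S, C, H, W = fmap_level.shape
    N = track_feats.shape[2]
    fmap2s = fmap_level.reshape(B, S, C, H * W)
    # (B, S, N, C) @ (B, S, C, H*W) -> (B, S, N, H*W)
    corrs = torch.matmul(track_feats, fmap2s)
    corrs = corrs / math.sqrt(C)  # assumed 1/sqrt(C) normalization
    return corrs.reshape(B, S, N, H, W)
```

    The result holds H×W values per point per frame, which is why corr_sample, per its docstring, computes, samples and discards each level instead of storing the whole correlation pyramid.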
    

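    Then, around the current predicted coords (divided by 2**i to match level i), bilinear sampling reads the correlation volume inside a local (2r+1)×(2r+1) window, where r = self.radius. Because corrs already contains the dot products with the track features, this step only looks up similarity values; it does not recompute them: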

            # Prepare sampling grid:
            # Scale down the coordinates for the current level.
            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
            # Make sure our precomputed delta grid is on the same device/dtype.
            delta_lvl = self.delta.to(coords.device).to(coords.dtype)
            # Now the grid for grid_sample is:
            # coords_lvl = centroid_lvl + delta_lvl   (broadcasted over grid)
            coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
    
            # Sample from the correlation volume using bilinear interpolation.
            # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
            corrs_sampled = bilinear_sampler(
                corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
            )
            # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
            corrs_sampled = corrs_sampled.view(B, S, N, -1)  # Now shape: (B, S, N, (2r+1)^2)
            out_pyramid.append(corrs_sampled)
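    The window lookup can be sketched with F.grid_sample. This is a hypothetical sample_local_window, not VGGT's bilinear_sampler: it builds its own offset grid in (x, y) order and assumes the align_corners=True pixel convention with zero padding outside the image (the real code takes its padding mode from self.padding_mode).

```python
import torch
import torch.nn.functional as F


def sample_local_window(corrs: torch.Tensor, coords: torch.Tensor, radius: int) -> torch.Tensor:
    """Sample a (2r+1) x (2r+1) window of a correlation volume around each point.

    corrs:  (B, S, N, H, W) correlation volume of one pyramid level
    coords: (B, S, N, 2) point positions (x, y) in this level's pixels
    returns (B, S, N, (2r+1)**2)
    """
    B, S, N, H, W = corrs.shape
    d = torch.arange(-radius, radius + 1, dtype=coords.dtype, device=coords.device)
    dy, dx = torch.meshgrid(d, d, indexing="ij")  # window rows step in y, columns in x
    delta = torch.stack([dx, dy], dim=-1)  # (2r+1, 2r+1, 2) offsets in (x, y) order
    pts = coords.reshape(B * S * N, 1, 1, 2) + delta  # (B*S*N, 2r+1, 2r+1, 2)
    # Normalize to [-1, 1] for grid_sample; samples outside the image read as zero
    gx = 2.0 * pts[..., 0] / (W - 1) - 1.0
    gy = 2.0 * pts[..., 1] / (H - 1) - 1.0
    grid = torch.stack([gx, gy], dim=-1)
    out = F.grid_sample(corrs.reshape(B * S * N, 1, H, W), grid,
                        mode="bilinear", padding_mode="zeros", align_corners=True)
    return out.reshape(B, S, N, -1)  # (B, S, N, (2r+1)**2)
```

    In the real loop, level i is queried at coords / 2**i, so the same 9×9 window spans 2**i times as many original feature-map pixels on coarser levels. This is how the pyramid widens the search range without enlarging the window.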
    

    Finally, all levels are concatenated along the last dimension:

        # Concatenate all levels along the last dimension.
        out = torch.cat(out_pyramid, dim=-1).contiguous()
        return out
    

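    So out_pyramid collects, on every pyramid level, the similarity between each point's track feature and the pixel features within self.radius of the point's current position. The concatenated output has shape [B, S, N, L]: for each of the N tracked points, in each of the S frames, local correlation values are sampled in a 9×9 window around the current estimate coords (self.radius = 4, so 2r+1 = 9) on each of the 7 pyramid levels, giving L = 9×9×7 = 567. The following steps use this similarity signal to decide in which direction and how far each tracked point should move.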

    # Operation ②: Displacement Encoding

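    This step measures how far each point has moved from its starting position and encodes that displacement. Before doing so, the loop rearranges the sampled correlations fcorrs (the output of corr_sample) and passes each 9×9×7 = 567-dimensional raw correlation vector through an MLP, corr_mlp, that compresses it to 128 dimensions: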

                corr_dim = fcorrs.shape[3]  # 567
                fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)  # [B,S,N,567] → [B,N,S,567] → [B*N, S, 567]
                fcorrs_ = self.corr_mlp(fcorrs_)  # [B*N, S, 567] → [B*N, S, 128]
    

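    This MLP compression distills a "summary of the matching signal": the raw 567 values are many and redundant (most positions in the windows are not matches), so the MLP learns to extract the useful matching patterns from them.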
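    The reshape before corr_mlp groups all S frames of a point into one sequence. A sketch of that bookkeeping; corr_mlp_sketch is a hypothetical stand-in whose hidden width (384) and GELU activation are assumptions, and only the 567 → 128 mapping comes from the text above:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for corr_mlp: hidden width and activation are assumptions
corr_mlp_sketch = nn.Sequential(
    nn.Linear(567, 384),
    nn.GELU(),
    nn.Linear(384, 128),
)


def compress_corrs(fcorrs: torch.Tensor, mlp: nn.Module) -> torch.Tensor:
    """fcorrs: (B, S, N, 567) -> (B*N, S, 128), one sequence of S frames per point."""
    B, S, N, L = fcorrs.shape
    x = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, L)  # row b*N + n holds point n of batch b
    return mlp(x)  # nn.Linear acts on the last dimension only
```

    Because nn.Linear only touches the last dimension, the compression is applied to each (point, frame) correlation vector independently; mixing across frames happens later, in the update transformer.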

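    Then the offset of each frame's current coordinates relative to the query frame is encoded as a sine/cosine embedding, and the raw displacement, normalized by self.max_scale, is concatenated to it (twice, as the code is written):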

                # Movement of current coords relative to query points
                flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
    
                flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
    
                # (In my trials, it is also okay to just add the flows_emb instead of concat)
                flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
    

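    coords[:, 0:1] holds the coordinates of the tracked pixels in the first frame; subtracting it from coords gives, for every frame, the displacement of the tracked position relative to the first frame.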

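    get_2d_embedding encodes the 2D displacement (dx, dy) into a high-dimensional space with sines and cosines. This is similar in spirit to the Transformer's sinusoidal positional encoding and lets the network tell displacements of different scales apart (small vs large offsets), although the frequencies here grow linearly, arange(0, C, 2) * (1000 / C), rather than geometrically: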

    def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
        """
        This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
    
        Args:
        - xy: The coordinates to generate the embedding from.
        - C: The size of the embedding.
        - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
    
        Returns:
        - pe: The generated 2D positional embedding.
        """
        B, N, D = xy.shape
        assert D == 2
    
        x = xy[:, :, 0:1]
        y = xy[:, :, 1:2]
        div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
    
        pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
        pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
    
        pe_x[:, :, 0::2] = torch.sin(x * div_term)
        pe_x[:, :, 1::2] = torch.cos(x * div_term)
    
        pe_y[:, :, 0::2] = torch.sin(y * div_term)
        pe_y[:, :, 1::2] = torch.cos(y * div_term)
    
        pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, C*2)
        if cat_coords:
            pe = torch.cat([xy, pe], dim=2)  # (B, N, C*2+2)
        return pe
    

    In effect, this tells the transformer, in terms it can work with, "how far this point has already moved".
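    The displacement encoding can be mirrored in a few lines to make the output width explicit. flow_embedding_sketch is a hypothetical helper that reproduces get_2d_embedding with cat_coords=False and then appends flows / max_scale twice, as the loop does; the actual value of max_scale is not given here, so any number passed for it is an assumption.

```python
import torch


def flow_embedding_sketch(flows: torch.Tensor, C: int, max_scale: float) -> torch.Tensor:
    """Mirror of get_2d_embedding(cat_coords=False) plus the normalized-flow concat.

    flows: (B*N, S, 2) displacement (dx, dy) of each frame relative to the query frame
    returns (B*N, S, 2*C + 4)
    """
    x, y = flows[..., 0:1], flows[..., 1:2]
    # Frequencies grow linearly: (0, 2, 4, ...) * 1000 / C, as in get_2d_embedding
    freqs = torch.arange(0, C, 2, dtype=torch.float32) * (1000.0 / C)
    # Interleave sin (even channels) and cos (odd channels), like pe[:, :, 0::2] / 1::2
    pe_x = torch.stack([torch.sin(x * freqs), torch.cos(x * freqs)], dim=-1).flatten(-2)
    pe_y = torch.stack([torch.sin(y * freqs), torch.cos(y * freqs)], dim=-1).flatten(-2)
    emb = torch.cat([pe_x, pe_y], dim=-1)  # (B*N, S, 2*C)
    norm = flows / max_scale
    return torch.cat([emb, norm, norm], dim=-1)  # the normalized flow appears twice
```

    With flows_emb_dim = C the result has 2C + 4 channels. Because sines and cosines are periodic, the appended linear flows / max_scale terms presumably help keep large displacements unambiguous.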

    # Operation ③: Update Transformer

    The motion information (flows_emb), the local matching scores (fcorrs_) and the current features (track_feats_) are concatenated:

    # Concatenate them as the input for the transformers
    transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
    

    Next, the positional embedding is obtained:

    # 2D positional embed
    # TODO: this can be much simplified
    pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
    sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
    
    sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
    

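    Here get_2d_sincos_pos_embed computes a positional embedding for the whole HH×WW feature-map grid, and sample_features4d samples it at each tracked point's first-frame position coords[:, 0]. This tells the transformer "where this point started".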

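    The two are added, a learned query_ref_token is added on top (one token for the query frame and a second one shared by all other frames), and the result is fed to EfficientUpdateFormer: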

    x = transformer_input + sampled_pos_emb
    
    # Add the query ref token to the track feats
    query_ref_token = torch.cat(
        [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
    )
    x = x + query_ref_token.to(x.device).to(x.dtype)
    
    # B, N, S, C
    x = rearrange(x, "(b n) s d -> b n s d", b=B)
    
    # Compute the delta coordinates and delta track features
    delta, _ = self.updateformer(x)
    

    This transformer computes attention along time (between frames) and across space (between points), and outputs a coordinate correction delta_coords and a feature correction delta_feats:

    class EfficientUpdateFormer(nn.Module):
        # ...
        def forward(self, input_tensor, mask=None):
            tokens = self.input_transform(input_tensor)  # B, N, S, hidden
            // ...
            for i in range(len(self.time_blocks)):
                # Time attention: each point independently attend across frames
                time_tokens = tokens.view(B * N, T, -1)
                time_tokens = self.time_blocks[i](time_tokens)
    
                # Space attention (every few layers): points interact within each frame
                if self.add_space_attn and ...:
                    space_tokens = tokens.permute(0,2,1,3).view(B*T, N, -1)
                    # virtual tokens act as communication bottleneck
                    virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens)
                    virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
                    point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens)
    

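    The two kinds of attention alternate (space attention is inserted only every few time blocks):

    • Time Attention: the S frame tokens of each point form one sequence for self-attention. The same point exchanges information across frames, learning what its temporal trajectory should look like.
    • Space Attention: all N points of the same frame form one sequence. Different points in a frame exchange information; for example, the model may learn to exploit a "rigid motion" cue: if the surrounding points all move right, this point probably should too.

    Space attention goes through a bottleneck of virtual tokens (64 virtual tracks) to avoid the O(N^2) cost of full point-to-point attention: first point→virtual (cross-attention), then self-attention among the virtual tokens, and finally virtual→point (cross-attention). With K virtual tokens the cost per frame is O(NK + K^2).

    Output: a delta for every point in every frame (a coordinate increment plus a feature increment).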
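    The time/space alternation and the virtual-token bottleneck can be sketched with nn.MultiheadAttention. This is a simplified, hypothetical module, not EfficientUpdateFormer: it keeps only the three attention steps with residual connections, omits the normalization and MLP sublayers, and uses arbitrary dimensions and head counts.

```python
import torch
import torch.nn as nn


class VirtualTokenSpaceAttnSketch(nn.Module):
    """Space attention routed through K learned virtual tokens (hypothetical)."""

    def __init__(self, dim: int = 64, num_virtual: int = 64, heads: int = 4):
        super().__init__()
        self.virtual = nn.Parameter(torch.randn(1, num_virtual, dim) * 0.02)
        self.point2virtual = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.virtual_self = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.virtual2point = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, point_tokens: torch.Tensor) -> torch.Tensor:
        """point_tokens: (B*T, N, dim), the N points of one frame per row."""
        v = self.virtual.expand(point_tokens.shape[0], -1, -1)
        # 1) point -> virtual: virtual tokens query the points (K x N scores)
        v = v + self.point2virtual(v, point_tokens, point_tokens)[0]
        # 2) virtual self-attention (K x K scores)
        v = v + self.virtual_self(v, v, v)[0]
        # 3) virtual -> point: points query the virtual tokens (N x K scores)
        return point_tokens + self.virtual2point(point_tokens, v, v)[0]


def time_then_space(tokens: torch.Tensor, time_attn: nn.Module, space: nn.Module) -> torch.Tensor:
    """tokens: (B, N, T, dim). One time-attention pass, then one space-attention pass."""
    B, N, T, D = tokens.shape
    x = tokens.reshape(B * N, T, D)  # one sequence of T frames per point
    x = x + time_attn(x, x, x)[0]
    # Regroup so that each row is one frame holding its N points
    x = x.reshape(B, N, T, D).permute(0, 2, 1, 3).reshape(B * T, N, D)
    x = space(x)
    return x.reshape(B, T, N, D).permute(0, 2, 1, 3)  # back to (B, N, T, D)
```

    No step ever forms an N×N score matrix, so the cost grows linearly with the number of query points for a fixed K.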

    # Operation ④: Update

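    The predicted coordinate delta is added to the current coordinates, and the feature delta is normalized (ffeat_norm), passed through ffeat_updater and added to the features as a residual: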

                # Update the track features
                track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
    
                track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3)  # BxSxNxC
    
                # B x S x N x 2
                coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
    
                # Force coord0 as query
                # because we assume the query points should not be changed
                coords[:, 0] = coords_backup[:, 0]
    

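    Note that the coordinates in the first (reference) frame are forcibly reset to their initial values: the query location is given, so it must not change.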
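    The update rule, including the detach at the top of the loop and the pinning of frame 0, can be exercised with a toy predictor. refine and the halfway-to-target predict_delta are hypothetical stand-ins for the update transformer, used only to show the control flow:

```python
import torch


def refine(coords0: torch.Tensor, predict_delta, iters: int) -> torch.Tensor:
    """coords0: (B, S, N, 2) initial guess; frame 0 holds the query positions."""
    query = coords0[:, 0].clone()
    coords = coords0.clone()
    for _ in range(iters):
        coords = coords.detach()  # no gradient flows through earlier iterations
        coords = coords + predict_delta(coords)  # apply the predicted (dx, dy)
        coords[:, 0] = query  # frame 0 is the query frame: pin it back
    return coords


# Toy stand-in for the update transformer: step halfway toward a known target
target = torch.tensor([[[[1.0, 1.0]], [[4.0, 2.0]], [[8.0, -4.0]]]])  # (B=1, S=3, N=1, 2)
coords0 = torch.zeros(1, 3, 1, 2)
out = refine(coords0, lambda c: 0.5 * (target - c), iters=3)
print(out[0, :, 0])
```

    Each toy iteration halves the remaining error in frames 1 and 2, which end at (3.5, 1.75) and (7.0, -3.5) after three iterations, while frame 0 stays at the query position even though the toy predictor proposes a nonzero delta for it.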

    # Predicting Visibility and Confidence

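    After all iterations, each point's final per-frame feature in track_feats has fused spatial and temporal information. The model then uses two simple linear predictors, vis_predictor and conf_predictor (followed by a sigmoid when apply_sigmoid is set), to decide whether the point is visible in each frame (it may be occluded or leave the image) and how confident the prediction is.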

            # B, S, N
            vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
            if apply_sigmoid:
                vis_e = torch.sigmoid(vis_e)
    
            if self.predict_conf:
                conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
                if apply_sigmoid:
                    conf_e = torch.sigmoid(conf_e)
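    The two heads can be sketched as follows, assuming each predictor is a single nn.Linear(latent_dim, 1) (the layer structure is not shown above); VisConfHeadsSketch is a hypothetical name:

```python
import torch
import torch.nn as nn


class VisConfHeadsSketch(nn.Module):
    """Per-point, per-frame visibility and confidence from track features (hypothetical)."""

    def __init__(self, latent_dim: int = 128):
        super().__init__()
        self.vis_predictor = nn.Linear(latent_dim, 1)  # assumed single linear layer
        self.conf_predictor = nn.Linear(latent_dim, 1)

    def forward(self, track_feats: torch.Tensor, apply_sigmoid: bool = True):
        B, S, N, C = track_feats.shape
        flat = track_feats.reshape(B * S * N, C)  # each (frame, point) pair is scored independently
        vis = self.vis_predictor(flat).reshape(B, S, N)
        conf = self.conf_predictor(flat).reshape(B, S, N)
        if apply_sigmoid:
            vis, conf = torch.sigmoid(vis), torch.sigmoid(conf)  # logits -> probabilities
        return vis, conf
```

    Because each head sees one (frame, point) feature at a time, any temporal reasoning about occlusion has to be encoded in track_feats beforehand, which is what the time attention in the update transformer provides.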
    

    # Summary

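    Rather than regressing a displacement for every pixel in one shot, as classic optical flow does, TrackHead follows the sparse point tracking paradigm: extract features -> guess a position -> check how similar the neighborhood looks -> let the transformer combine context and propose a correction -> move the point -> check the neighborhood again... and so on. The design targets robustness to large displacements (through the multi-scale correlation pyramid), long sequences (through temporal attention across frames) and occlusion (through the per-frame visibility prediction).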

    Created: 2026-03-16 23:27:00

    Updated: 2026-03-17 06:21:39