Yin的笔记本

vuePress-theme-reco Howard Yin    2021 - 2025
Yin的笔记本 Yin的笔记本

Choose mode

  • dark
  • auto
  • light
Home
Category
  • CNCF
  • Docker
  • namespaces
  • Kubernetes
  • Kubernetes对象
  • Linux
  • MyIdeas
  • Revolution
  • WebRTC
  • 云计算
  • 人工智能
  • 分布式
  • 图像处理
  • 图形学
  • 微服务
  • 数学
  • OJ笔记
  • 博弈论
  • 形式语言与自动机
  • 数据库
  • 服务器运维
  • 编程语言
  • C
  • Git
  • Go
  • Java
  • JavaScript
  • Python
  • Nvidia
  • Shell
  • Tex
  • Rust
  • Vue
  • 视频编解码
  • 计算机网络
  • SDN
  • 论文笔记
  • 讨论
  • 边缘计算
  • 量子信息技术
Tag
TimeLine
About
查看源码
author-avatar

Howard Yin

304

Article

153

Tag

Home
Category
  • CNCF
  • Docker
  • namespaces
  • Kubernetes
  • Kubernetes对象
  • Linux
  • MyIdeas
  • Revolution
  • WebRTC
  • 云计算
  • 人工智能
  • 分布式
  • 图像处理
  • 图形学
  • 微服务
  • 数学
  • OJ笔记
  • 博弈论
  • 形式语言与自动机
  • 数据库
  • 服务器运维
  • 编程语言
  • C
  • Git
  • Go
  • Java
  • JavaScript
  • Python
  • Nvidia
  • Shell
  • Tex
  • Rust
  • Vue
  • 视频编解码
  • 计算机网络
  • SDN
  • 论文笔记
  • 讨论
  • 边缘计算
  • 量子信息技术
Tag
TimeLine
About
查看源码
  • Torch-Pruning解析

    • 底层调用
      • DG.register_customized_layer
      • DG.build_dependency
      • pruning_group = DG.get_pruning_group
      • pruning_group.exec()

[Torch-Pruning](https://github.com/VainF/Torch-Pruning)解析

vuePress-theme-reco Howard Yin    2021 - 2025

Torch-Pruning解析


Howard Yin 2023-03-10 13:52:02 人工智能计算加速

# 底层调用

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_pruning as tp
from typing import Sequence

############
# Customize your layer
#
class CustomizedLayer(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.in_dim = in_dim
        self.scale = nn.Parameter(torch.Tensor(self.in_dim))
        self.bias = nn.Parameter(torch.Tensor(self.in_dim))
        self.fc = nn.Linear(self.in_dim, self.in_dim)
    
    def forward(self, x):
        norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()
        x = torch.div(x, norm)
        return self.fc(x * self.scale + self.bias)

    def __repr__(self):
        return "CustomizedLayer(in_dim=%d)"%(self.in_dim)

class FullyConnectedNet(nn.Module):
    """https://github.com/VainF/Torch-Pruning/issues/21"""
    def __init__(self, input_size, num_classes, HIDDEN_UNITS):
        super().__init__()
        self.fc1 = nn.Linear(input_size, HIDDEN_UNITS)
        self.customized_layer = CustomizedLayer(HIDDEN_UNITS)
        self.fc2 = nn.Linear(HIDDEN_UNITS, num_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.customized_layer(x)
        y_hat = self.fc2(x)
        return y_hat

############################
# Implement your pruning function for the customized layer
# You should implement the following class fucntions:
# 1. prune_out_channels
# 2. prune_in_channels
# 3. get_out_channels
# 4. get_in_channels

class MyPruner(tp.pruner.BasePruningFunc):

    def prune_out_channels(self, layer: CustomizedLayer, idxs: Sequence[int]) -> nn.Module: 
        keep_idxs = list(set(range(layer.in_dim)) - set(idxs))
        keep_idxs.sort()
        layer.in_dim = layer.in_dim-len(idxs)
        layer.scale = torch.nn.Parameter(layer.scale.data.clone()[keep_idxs])
        layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs])
        tp.prune_linear_in_channels(layer.fc, idxs)
        tp.prune_linear_out_channels(layer.fc, idxs)
        return layer

    def get_out_channels(self, layer):
        return self.in_dim
    
    # identical functions
    prune_in_channels = prune_out_channels
    get_in_channels = get_out_channels
        
model = FullyConnectedNet(128, 10, 256)

DG = tp.DependencyGraph()

# 1. Register your customized layer
my_pruner = MyPruner()
DG.register_customized_layer(
    CustomizedLayer, 
    my_pruner)

# 2. Build dependency graph
DG.build_dependency(model, example_inputs=torch.randn(1,128))

# 3. get a pruning group according to the dependency graph. idxs is the indices of pruned filters.
pruning_group = DG.get_pruning_group( model.fc1, tp.prune_linear_out_channels, idxs=[0, 1, 6] )
print(pruning_group)

# 4. execute this group (prune the model)
pruning_group.exec()
print("The pruned model:\n", model)
print("Output: ", model(torch.randn(1,128)).shape)

assert model.fc1.out_features==253 and model.customized_layer.in_dim==253 and model.fc2.in_features==253
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

DG = tp.DependencyGraph()是整个系统的核心模块,看它的调用流程也就看懂了整个Torch-Pruning的裁剪过程。

从DG = tp.DependencyGraph()开始的调用从上往下看,可以看见DG.register_customized_layer、DG.build_dependency、pruning_group = DG.get_pruning_group和pruning_group.exec()。

# DG.register_customized_layer

DG.register_customized_layer用于注册“裁剪方式”。“裁剪方式”与Pytorch中的层一一对应

DG = tp.DependencyGraph()内部已经注册了一些层的默认裁剪方式,包括卷积层和线性层等:

PrunerBox = {
    ops.OPTYPE.CONV: ConvPruner(),
    ops.OPTYPE.LINEAR: LinearPruner(),
    ops.OPTYPE.BN: BatchnormPruner(),
    ops.OPTYPE.DEPTHWISE_CONV: DepthwiseConvPruner(),
    ops.OPTYPE.PRELU: PReLUPruner(),
    ops.OPTYPE.LN: LayernormPruner(),
    ops.OPTYPE.EMBED: EmbeddingPruner(),
    ops.OPTYPE.PARAMETER: ParameterPruner(),
    ops.OPTYPE.MHA: MultiheadAttentionPruner(),
    ops.OPTYPE.LSTM: LSTMPruner()
}
1
2
3
4
5
6
7
8
9
10
11
12
_dummy_pruners = {
    ops.OPTYPE.CONCAT: ops.ConcatPruner(),
    ops.OPTYPE.SPLIT: ops.SplitPruner(),
    ops.OPTYPE.ELEMENTWISE: ops.ElementWisePruner(),
    ops.OPTYPE.CUSTOMIZED: None,
}
1
2
3
4
5
6

# DG.build_dependency

DG.build_dependency用于解析模型中层之间的调用关系,即解析torch.nn.Module.forward中的内容。

采取的方式是用一个样例输入执行推断过程,在推断过程进行trace。 具体的trace方案在DG.build_dependency._trace函数中。 简言之,就是通过torch.nn.Module.register_forward_hook注册hook,从而在每个forward函数被调用时记录下调用顺序。

具体来说,DG.build_dependency._trace函数trace出的调用顺序包括两方面:输入来自哪些层、输出到哪些层。 得知这些信息后,DG.build_dependency会调用DG._build_dependency,这个函数将每一个层与层之间的调用顺序(x层的输出到y层的输入)构建为一个tp.Dependency,加进相关层的node.dependencies中:

    def _build_dependency(self, module2node):

        # There will be a dependency between two pruning operations if they:
        # 1) connects to each other in the computational graph or
        # 2) are equivalent, i.e., applied to the same layer and works in the same way.
        # Note that for some units like BN and PReLU, pruning output channels is equivalent to pruning output_channels
        # Rule 2) is designed for this case.

        for _, node in module2node.items():
            # Rule 1) - Input connections
            for in_node in node.inputs:
                handler = self.REGISTERED_PRUNERS.get(in_node.type)
                if handler is None:
                    handler = self.CUSTOMIZED_PRUNERS[in_node.class_type]
                handler = handler.prune_out_channels

                trigger = self.REGISTERED_PRUNERS.get(node.type)
                if trigger is None:
                    trigger = self.CUSTOMIZED_PRUNERS[node.class_type]
                trigger = trigger.prune_in_channels

                dep = Dependency(
                    trigger=trigger, handler=handler, source=node, target=in_node
                )
                node.dependencies.append(dep)

            # Rule 1) - Output connections
            for out_node in node.outputs:
                trigger = self.REGISTERED_PRUNERS.get(node.type)
                if trigger is None:
                    trigger = self.CUSTOMIZED_PRUNERS[node.class_type]
                trigger = trigger.prune_out_channels

                handler = self.REGISTERED_PRUNERS.get(out_node.type)
                if handler is None:
                    handler = self.CUSTOMIZED_PRUNERS[out_node.class_type]
                handler = handler.prune_in_channels

                dep = Dependency(
                    trigger=trigger, handler=handler, source=node, target=out_node
                )
                node.dependencies.append(dep)
......
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

看tp.Dependency的输入trigger=trigger, handler=handler, source=node, target=out_node,很明显这表示:当source层的裁剪过程trigger被调用时,需要调用target层的handler。

此外,进一步看REGISTERED_PRUNERS和CUSTOMIZED_PRUNERS:

class DependencyGraph(object):

    def __init__(self):
        _dummy_pruners = {
            ops.OPTYPE.CONCAT: ops.ConcatPruner(),
            ops.OPTYPE.SPLIT: ops.SplitPruner(),
            ops.OPTYPE.ELEMENTWISE: ops.ElementWisePruner(),
            ops.OPTYPE.CUSTOMIZED: None,
        }
        self.REGISTERED_PRUNERS = function.PrunerBox.copy()  # shallow copy
        self.REGISTERED_PRUNERS.update(_dummy_pruners)
        self.CUSTOMIZED_PRUNERS = {}
        self.IGNORED_LAYERS = []
......
    
    def register_customized_layer(
        self,
        layer_type,
        layer_pruner,
    ):
        """Register a customized layer for pruning.

        Args:
            layer_type (class): the type of layer
            pruner (tp.pruner.BasePruningFunc): a pruner for the given layer type.
        """
        self.CUSTOMIZED_PRUNERS[layer_type] = layer_pruner
......
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

可以发现它们实际上都继承自tp.pruner.BasePruningFunc,PrunerBox里面的几个是已实现的_dummy_pruners里面的几个都未实现。 所以很明显,只有在DG.register_customized_layer或者内部自带的继承自tp.pruner.BasePruningFunc的类的类方法才能被作为tp.Dependency里的trigger和handler。 再看这个tp.pruner.BasePruningFunc:

class BasePruningFunc(ABC):
    TARGET_MODULES = ops.TORCH_OTHERS  # None

    def __init__(self, dim=1):
        self.dim = dim

    @abstractclassmethod
    def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]):
        raise NotImplementedError

    @abstractclassmethod
    def prune_in_channels(self, layer: nn.Module, idxs: Sequence[int]):
        raise NotImplementedError

    @abstractclassmethod
    def get_out_channels(self, layer: nn.Module):
        raise NotImplementedError

    @abstractclassmethod
    def get_in_channels(self, layer: nn.Module):
        raise NotImplementedError
        
......
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23

这不就是“裁输入”和“裁输出”吗😂

所以每一层的node.dependencies中实际上包含的都是:“trigger=上一层的输出被裁剪, handler=裁剪当前层的输入”或者“trigger=下一层的输入被裁剪, handler=裁剪当前层的输出”。 看到这就清晰多了,这框架主打的自动解析依赖关系完成裁剪的功能归根结底就是以这种方式组织的。

# pruning_group = DG.get_pruning_group

在DG.build_dependency之后,模型中每个Module之间的调用关系就清楚了,于是输入任意一个要裁的“节点”(输出矩阵中的某个channel,也对应卷积层中的一个卷积核)都能知道该节点在模型中的前后依赖关系(例如该节点被裁剪导致输出少了一个channel,以之作为输入的所有层也需要相应的进行修改)。

DG.get_pruning_group就是这样一个根据输入的某个节点的裁剪方案输出整体裁剪方案的函数。

其输入为要裁的层module: nn.Module、裁剪该节点的函数pruning_fn: typing.Callable和裁哪些channelidxs: typing.Union[list, tuple]。 这里的pruning_fn虽然只是typing.Callable,但结合DG.build_dependency的解析,实际的依赖关系是由一个个tp.Dependency所记录函数间的触发关系所描述的。 所以输入的pruning_fn要想能触发依赖关系上的相关裁剪函数,其必须是tp.Dependency中已经记录过的函数,换言之,它必须要是一个继承于tp.pruner.BasePruningFunc子类的类方法,其逻辑上的功能其实是指定是裁这一层的输入还是输出。

其输出的修改方案类名为tp.DependencyGroup,其是由tp.Dependency组成的数组。 DG.get_pruning_group的本质就是从module对应的node.dependencies(DG.build_dependency中构建的依赖关系)中找出trigger=pruning_fn的那些tp.Dependency组成tp.DependencyGroup。

# pruning_group.exec()

最后就是执行这个修改DG.get_pruning_group输出的修改方案,很好理解,就是按照tp.DependencyGroup里的tp.Dependency一个个执行它们的handler就行了。

帮助我们改善此页面!
创建于: 2023-03-09 03:43:43

更新于: 2023-03-10 13:52:23