MXNet中,gluon.Block类和gluon.HybridBlock类,和Pytorch中的nn.Module类一样,我们通过继承Block类和HybridBlock类可以很灵活的搭建我们自己的网络模型,这里总结一下HybridBlock类使用过程中的一些注意点。

HybridBlock类和Block类的区别

  HybridBlock类继承至Block类,所以HybridBlock类有Block类的全部方法和属性。HybridBlock同时支持符号式编程和命令式编程,HybridBlock类可以调用hybridize()方法,从而可以从命令式变为符号式,从而将动态图转化为静态图,提高模型的计算性能和移植性。下面是两者的比较:

HybridBlock类 Block类
重写方法 __init__()hybrid_forward(self, F, x, *args, **kwargs) __init__()forwad(self,x,*args)
是否支持符号式
支持输入参数 位置式参数、关键字参数 只支持位置式参数
是否支持导出符号模型



  可以看出HybridBlock类除了多支持符号式编程外,和Block基本没什么区别,但是注意到支持输入参数那一栏,hybrid_forward函数还支持输入关键字参数,这点也和Block不一样,下面详细分析一下hybrid_forward的调用过程。

hybrid_forward()分析

  当我们构建一个HybridBlock类后,需要重写其|__init__()hybrid_forward()方法,而我们在源码中可以看到,当一个HybridBlock类进行forward操作时,其流程如下:

  __call__()————>forward()————>hybrid_forward()

  可以看出HybridBlock类是通过forward()方法中来调用hybrid_forward()。由于HybridBlock类中的forward()方法已经被重写过了,所以我们只需要重写hybrid_forward()就可以了,其中forward()函数如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def forward(self, x, *args):
"""Defines the forward computation. Arguments can be either
:py:class:`NDArray` or :py:class:`Symbol`."""
if isinstance(x, NDArray):
with x.context as ctx:
if self._active:
return self._call_cached_op(x, *args)

try:
params = {i: j.data(ctx) for i, j in self._reg_params.items()}
except DeferredInitializationError:
self._deferred_infer_shape(x, *args)
for _, i in self.params.items():
i._finish_deferred_init()
params = {i: j.data(ctx) for i, j in self._reg_params.items()}

#!!!!!!!!!!!!!!!!!!!!!!注意这里的 params参数!!!!!!!!!!!!!!!!!!!!!!!!
return self.hybrid_forward(ndarray, x, *args, **params)

assert isinstance(x, Symbol), \
"HybridBlock requires the first argument to forward be either " \
"Symbol or NDArray, but got %s"%type(x)
params = {i: j.var() for i, j in self._reg_params.items()}
with self.name_scope():
#!!!!!!!!!!!!!!!!!!!!!!注意这里的 params参数!!!!!!!!!!!!!!!!!!!!!!!!
return self.hybrid_forward(symbol, x, *args, **params)

  注意到forward()函数中会将在__init__()方法中注册的参数,会作为关键字参数传递给hybrid_forward。也就是说我们在__init__()方法中用self.params.get()self.params.get_constant()注册的所有参数都会作为关键字参数传递给hybrid_forward

  下面来看一个例子,我们创建了一个gluon.HybridBlock,并且在__init__()方法中注册了一个Constant参数,我们在forward时,并没有输入anchors参数,可以看到该参数会直接传递给hybrid_forward。并且forward时会自动将所有注册的参数复制到和x相同的设备中(CPU or GPU),然后再进行运算。详细程序如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Generator(gluon.HybridBlock):

def __init__(self,alloc_size=(5, 5),**kwargs):
super(SFDAnchorGenerator, self).__init__(**kwargs)
anchors = self._generate_anchors(alloc_size)
self._key = 'anchor_1'

#注册了一个Constant参数
self.anchors = self.params.get_constant(self._key, anchors)


def _generate_anchors(self,alloc_size):
return nd.random.uniform(shape=alloc_size)


#定义的hybrid_forward中需要有一个anchors参数
def hybrid_forward(self, F, x, anchors):
print(anchors)


generator = Generator()
generator.initialize()
x = nd.random.uniform(shape=(3,3))
#可以看到,这里我并没有输入 anchors 参数,程序没有报错,而是直接打印出我们的参数
generator(x)
print(generator._reg_params)

Out:

[[0.50962436 0.9767611 0.05571469 0.6048455 0.4511592 ]
[0.7392636 0.01998767 0.03918779 0.44171092 0.28280696]
[0.9795867 0.12019657 0.35944447 0.2961402 0.48089352]
[0.11872771 0.68866116 0.31798318 0.8804759 0.41426298]
[0.9182355 0.0641475 0.21682213 0.6924721 0.5651889 ]]
<NDArray 5x5 @cpu(0)>

{'anchors': Constant sfdanchorgenerator13_anchor_1 (shape=(5, 5), dtype=<class 'numpy.float32'>)}

总结

  需要注意到,当使用HybridBlock类时,所有在__init__()方法注册的参数,均会以关键字参数的形式传递到hybrid_forward中,所以如果我们重写的hybrid_forward函数如果没有对应参数,那么程序运行时将报错。并且HybridBlock类在前向运算时,会自动将传入的关键字参数都复制到与输入参数x相同的设备上(CPU or GPU),灵活使用这一点,可以让我们更加灵活的搭建自己的模型。