Fork me on GitHub

mxnet学习心得【1】

mxnet中最基本的元素有两种:mx.ndarraymx.symbol.前者是mxnet版的numpy.ndarray.后者是符号变量。mxnet是由基于这两个元素的两个项目融合而成,因此取名为mxnet,意为mix net.

理解symbol类

mx.symbol本身只是一个占位符,在计算之前都没有进行赋值。这些符号变量可以通过mxnet支持的operator进行运算,比如mxnet重载了加减乘除,他们都是二元运算符。可以将每个operator看成是一个函数,有的函数是一个输入,有和函数是两个输入。输入也可以是一个到多个。

理解symbol类的operator

Deep learning中最重要的层在mxnet中也是一个operator,通过operator可以产生新的symbol。可以认为operator是一个包含在symbol中的对symbol中包含的symbol进行运算的操作符,在symbol中可以有operator也可以没有。输入数据或上一层的输出结果就是成为了这个函数的自变量,其输出是下一层的输入。如果operator是层的话,这个函数还会引入自带的参数symbol(我将其理解为默认的symbol)weightbias。所有的这些,每进行一次operator,都可以用print(net.tojson)查看网络的结构.

执行symbol

当我们将symbol设计好之后,要对数据进行赋值才能执行。对数据赋值的是bind方法。这个方法有一个设备参数用来指定执行symbol的设备(cpu/gpu),和一个dict参数来指定所有的符号对应的值。这里的值的类型需要是mx.ndarry。这个方法会返回一个executor对象。调用executor.forward()就会执行symbol中对应的计算,并将返回值存储在executor中,返回值的取出方法为executor.output[0].asnumpy()

Reference:
[1] https://github.com/dmlc/mxnet-notebooks

No pain, No gain