diff --git a/minpy/array_variants/mxnet/mxnet_core.py b/minpy/array_variants/mxnet/mxnet_core.py index 6716d74..49343d6 100644 --- a/minpy/array_variants/mxnet/mxnet_core.py +++ b/minpy/array_variants/mxnet/mxnet_core.py @@ -112,6 +112,10 @@ def def_grads(prims): prims('dot').def_grad(lambda ans, a, b: lambda g: mx.nd.dot(g, b, transpose_b=True)) prims('dot').def_grad( lambda ans, a, b: lambda g: mx.nd.dot(a, g, transpose_a=True), argnum=1) + # batch_dot + prims('batch_dot').def_grad(lambda ans, a, b: lambda g: mx.nd.batch_dot(g, b, transpose_b=True)) + prims('batch_dot').def_grad( + lambda ans, a, b: lambda g: mx.nd.batch_dot(a, g, transpose_a=True), argnum=1) # non-linear prims('tanh').def_grad(lambda ans, x: lambda g: g * (1 - ans ** 2)) prims('exp').def_grad(lambda ans, x: lambda g: g * ans)