diff --git a/R/inference-tensorflow.R b/R/inference-tensorflow.R index c93732d..6002ee3 100644 --- a/R/inference-tensorflow.R +++ b/R/inference-tensorflow.R @@ -162,7 +162,7 @@ inference_tensorflow <- function(Y, Q = -tf$einsum('nc,cn->', gamma_fixed, p_y_on_c_unorm) - p_y_on_c_norm <- tf$reshape(tf$reduce_logsumexp(p_y_on_c_unorm, 0L), shape(1,-1)) + p_y_on_c_norm <- tf$reshape(tf$reduce_logsumexp(p_y_on_c_unorm, 0L), as_tensor(shape(1,NULL))) gamma <- tf$transpose(tf$exp(p_y_on_c_unorm - p_y_on_c_norm))