diff --git a/src/worker.cc b/src/worker.cc index 2afa8b06fe..0bcf556f3f 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -297,7 +297,7 @@ int Worker::Update(int step, Param* param) { int Worker::CollectAll(int step, NeuralNet* net) { auto& layers = net->layers(); for (auto& layer : layers) { - if (layer->partition_id() == id_) { + if (layer->partition_id() == id_ && layer->unroll_index() == 0) { for (Param* p : layer->GetParams()) { Collect(step, p); }