diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/EdgeProperty.java b/tez-api/src/main/java/org/apache/tez/dag/api/EdgeProperty.java index 1850060ece..bb0b7f1e48 100644 --- a/tez-api/src/main/java/org/apache/tez/dag/api/EdgeProperty.java +++ b/tez-api/src/main/java/org/apache/tez/dag/api/EdgeProperty.java @@ -261,6 +261,14 @@ public EdgeManagerPluginDescriptor getEdgeManagerDescriptor() { return edgeManagerDescriptor; } + /** + * Returns a new EdgeProperty with the given EdgeManagerPluginDescriptor. + */ + public EdgeProperty withDescriptor(EdgeManagerPluginDescriptor newDescriptor) { + return new EdgeProperty(newDescriptor, this.dataMovementType, this.dataSourceType, + this.schedulingType, this.outputDescriptor, this.inputDescriptor); + } + @Override public String toString() { return "{ " + dataMovementType + " : " + inputDescriptor.getClassName() diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java index 62902b8f51..71766dfa20 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java @@ -1998,7 +1998,12 @@ private void setParallelismWrapper(int parallelism, VertexLocationHint vertexLoc Vertex sourceVertex = appContext.getCurrentDAG().getVertex(entry.getKey()); Edge edge = sourceVertices.get(sourceVertex); try { - edge.setEdgeProperty(entry.getValue()); + if (edge != null) { + edge.setEdgeProperty(entry.getValue()); + } else { + LOG.warn("Edge is null, sourceVertex = {}, entry.getValue() = {}", + sourceVertex, entry.getValue()); + } } catch (Exception e) { throw new TezUncheckedException(e); } diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java index b05c45ad96..5231e8bb65 100644 --- a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java +++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java @@ -27,6 +27,7 @@ import org.apache.tez.dag.api.EdgeManagerPluginContext; import org.apache.tez.dag.api.EdgeManagerPluginDescriptor; import org.apache.tez.dag.api.EdgeManagerPluginOnDemand; +import org.apache.tez.dag.api.EdgeProperty; import org.apache.tez.dag.api.TezUncheckedException; import org.apache.tez.dag.api.UserPayload; import org.apache.tez.dag.api.VertexManagerPluginContext; @@ -52,6 +53,7 @@ import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.HashMap; /** * Starts scheduling tasks when number of completed source tasks crosses @@ -520,6 +522,30 @@ ReconfigVertexParams computeRouting() { for(Map.Entry entry : bipartiteItr) { entry.getValue().newDescriptor = descriptor; } + + // Additionally, update custom edges. + Map outputEdges = getContext().getOutputVertexEdgeProperties(); + Map updatedEdges = new HashMap<>(); + for (Map.Entry entry : outputEdges.entrySet()) { + if (entry.getValue().getDataMovementType() == EdgeProperty.DataMovementType.CUSTOM) { + // Build a new custom edge manager configuration with updated parallelism. + CustomShuffleEdgeManagerConfig customConfig = new CustomShuffleEdgeManagerConfig( + currentParallelism, finalTaskParallelism, basePartitionRange, + (remainderRangeForLastShuffler > 0 ? remainderRangeForLastShuffler : basePartitionRange)); + EdgeManagerPluginDescriptor newDescriptor = EdgeManagerPluginDescriptor.create(CustomShuffleEdgeManager.class.getName()); + newDescriptor.setUserPayload(customConfig.toUserPayload()); + + // Update the EdgeProperty with the new descriptor. + EdgeProperty updatedProp = entry.getValue().withDescriptor(newDescriptor); + updatedEdges.put(entry.getKey(), updatedProp); + } + } + + // If any custom edges were updated, propagate the new configuration. + if (!updatedEdges.isEmpty()) { + getContext().reconfigureVertex(finalTaskParallelism, null, updatedEdges); + } + ReconfigVertexParams params = new ReconfigVertexParams(finalTaskParallelism, null); return params;