diff --git a/tez-dag/src/main/java/org/apache/tez/dag/api/client/rpc/DAGClientAMProtocolBlockingPBServerImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/api/client/rpc/DAGClientAMProtocolBlockingPBServerImpl.java
index 5c24a27908..dd1b89fac2 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/api/client/rpc/DAGClientAMProtocolBlockingPBServerImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/api/client/rpc/DAGClientAMProtocolBlockingPBServerImpl.java
@@ -20,6 +20,7 @@
 
 import java.io.IOException;
 import java.security.AccessControlException;
+import java.security.PrivilegedExceptionAction;
 import java.util.List;
 import java.util.Map;
 
@@ -57,18 +58,33 @@
 
 import com.google.protobuf.RpcController;
 import com.google.protobuf.ServiceException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 public class DAGClientAMProtocolBlockingPBServerImpl implements DAGClientAMProtocolBlockingPB {
 
+  private static final Logger LOG = LoggerFactory.getLogger(DAGClientAMProtocolBlockingPBServerImpl.class);
+
   DAGClientHandler real;
   final FileSystem stagingFs;
 
+  UserGroupInformation amUGI = null;
+  UserGroupInformation rpcUGI = null;
+
   public DAGClientAMProtocolBlockingPBServerImpl(DAGClientHandler real, FileSystem stagingFs) {
     this.real = real;
     this.stagingFs = stagingFs;
+    try {
+      amUGI = UserGroupInformation.getCurrentUser();
+    } catch (IOException e) {
+      // Do not fail construction: amUGI is only needed when a request is read from a serialized request path.
+      LOG.error("Exception while getting current user", e);
+    }
   }
 
   private UserGroupInformation getRPCUser() throws ServiceException {
+    // rpcUGI is expected to be null outside unit tests, which inject it via reflection.
+    if (rpcUGI != null) return rpcUGI;
     try {
       return UserGroupInformation.getCurrentUser();
     } catch (IOException e) {
@@ -166,18 +182,24 @@ public SubmitDAGResponseProto submitDAG(RpcController controller,
     real.updateLastHeartbeatTime();
     try{
       if (request.hasSerializedRequestPath()) {
-        // need to deserialize large request from hdfs
+        // Read the serialized request as the AM user (amUGI). amUGI should never be null here, but if it is,
+        // fall back to the RPC user as a best effort (the read may still fail).
+        UserGroupInformation userToReadHDFS = amUGI == null ? user : amUGI;
         Path requestPath = new Path(request.getSerializedRequestPath());
-        FileSystem fs = requestPath.getFileSystem(stagingFs.getConf());
-        try (FSDataInputStream fsDataInputStream = fs.open(requestPath)) {
-          CodedInputStream in =
-              CodedInputStream.newInstance(fsDataInputStream);
-          in.setSizeLimit(Integer.MAX_VALUE);
-          request = SubmitDAGRequestProto.parseFrom(in);
-        } catch (IOException e) {
-          throw wrapException(e);
-        }
+        LOG.debug("Using the user {} to get the DAG plan from HDFS", userToReadHDFS);
+
+        // IOException is intentionally not caught inside the action: UserGroupInformation.doAs rethrows it
+        // unchanged for the outer catch to wrap; a ServiceException would surface as UndeclaredThrowableException.
+        request = userToReadHDFS.doAs((PrivilegedExceptionAction<SubmitDAGRequestProto>) () -> {
+          FileSystem fs = requestPath.getFileSystem(stagingFs.getConf());
+          try (FSDataInputStream fsDataInputStream = fs.open(requestPath)) {
+            CodedInputStream in = CodedInputStream.newInstance(fsDataInputStream);
+            in.setSizeLimit(Integer.MAX_VALUE);
+            return SubmitDAGRequestProto.parseFrom(in);
+          }
+        });
       }
+
       DAGPlan dagPlan = request.getDAGPlan();
       Map additionalResources = null;
       if (request.hasAdditionalAmResources()) {
@@ -188,6 +210,9 @@ public SubmitDAGResponseProto submitDAG(RpcController controller,
       return SubmitDAGResponseProto.newBuilder().setDagId(dagId).build();
     } catch(IOException | TezException e) {
       throw wrapException(e);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw wrapException(e);
     }
   }
 
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/api/client/rpc/TestDAGClientAMProtocolBlockingPBServerImpl.java b/tez-dag/src/test/java/org/apache/tez/dag/api/client/rpc/TestDAGClientAMProtocolBlockingPBServerImpl.java
index 5f6552ee81..807ba77129 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/api/client/rpc/TestDAGClientAMProtocolBlockingPBServerImpl.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/api/client/rpc/TestDAGClientAMProtocolBlockingPBServerImpl.java
@@ -136,4 +136,71 @@ public void testSubmitDagInSessionWithLargeDagPlan() throws Exception {
     assertEquals(lrURL.getPort(), port);
     assertEquals(lrURL.getFile(), path);
   }
+
+  @Test(timeout = 5000)
+  public void testSubmitDAGUserGroupInformation() throws Exception {
+    // Create a simple DAG plan and write it to a file
+    String dagPlanName = "test-dag";
+    File requestFile = tmpFolder.newFile("request-file");
+    TezConfiguration conf = new TezConfiguration();
+
+    DAGPlan dagPlan = DAG.create(dagPlanName)
+        .addVertex(Vertex.create("V", ProcessorDescriptor.create("P"), 1))
+        .createDag(conf, null, null, null, false);
+
+    // Write DAG plan to file
+    try (FileOutputStream fileOutputStream = new FileOutputStream(requestFile)) {
+      SubmitDAGRequestProto.newBuilder().setDAGPlan(dagPlan).build().writeTo(fileOutputStream);
+    }
+
+    // Setup mocks
+    DAGClientHandler dagClientHandler = mock(DAGClientHandler.class);
+    ACLManager aclManager = mock(ACLManager.class);
+    FileSystem mockFs = mock(FileSystem.class);
+    UserGroupInformation mockAmUgi = mock(UserGroupInformation.class);
+    UserGroupInformation mockRpcUgi = mock(UserGroupInformation.class);
+
+    // DAG request with file
+    SubmitDAGRequestProto request = SubmitDAGRequestProto.newBuilder()
+        .setSerializedRequestPath(requestFile.getAbsolutePath())
+        .build();
+
+    when(mockAmUgi.doAs(any(java.security.PrivilegedExceptionAction.class))).thenReturn(request);
+    when(mockRpcUgi.doAs(any(java.security.PrivilegedExceptionAction.class))).thenReturn(request);
+
+    // Create spy on server impl with mocked FileSystem
+    DAGClientAMProtocolBlockingPBServerImpl serverImpl = spy(new DAGClientAMProtocolBlockingPBServerImpl(
+        dagClientHandler, mockFs));
+
+    // Mock behavior
+    when(dagClientHandler.getACLManager()).thenReturn(aclManager);
+    when(aclManager.checkAMModifyAccess(any(UserGroupInformation.class))).thenReturn(true);
+    when(dagClientHandler.submitDAG(any(DAGPlan.class), any())).thenReturn("dag-id");
+    when(mockFs.getConf()).thenReturn(conf);
+
+    // Set the RPC UGI
+    java.lang.reflect.Field rpcUGIField = DAGClientAMProtocolBlockingPBServerImpl.class.getDeclaredField("rpcUGI");
+    rpcUGIField.setAccessible(true);
+    rpcUGIField.set(serverImpl, mockRpcUgi);
+
+    // Test Case 1: When amUGI is available
+    // Set the amUGI field using reflection
+    java.lang.reflect.Field amUGIField = DAGClientAMProtocolBlockingPBServerImpl.class.getDeclaredField("amUGI");
+    amUGIField.setAccessible(true);
+    amUGIField.set(serverImpl, mockAmUgi);
+
+    serverImpl.submitDAG(null, request);
+
+    // Verify amUGI was used for doAs
+    verify(mockAmUgi).doAs(any(java.security.PrivilegedExceptionAction.class));
+
+    // Test Case 2: When amUGI is null
+    // Set amUGI to null
+    amUGIField.set(serverImpl, null);
+
+    // Submit DAG with serialized path
+    serverImpl.submitDAG(null, request);
+    // Verify RPC user (mockRpcUgi) was used for doAs
+    verify(mockRpcUgi).doAs(any(java.security.PrivilegedExceptionAction.class));
+  }
 }