Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.io.IOException;
import java.security.AccessControlException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -57,18 +58,33 @@

import com.google.protobuf.RpcController;
import com.google.protobuf.ServiceException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DAGClientAMProtocolBlockingPBServerImpl implements DAGClientAMProtocolBlockingPB {

private static final Logger LOG = LoggerFactory.getLogger(DAGClientAMProtocolBlockingPBServerImpl.class);

DAGClientHandler real;
final FileSystem stagingFs;

UserGroupInformation amUGI = null;
UserGroupInformation rpcUGI = null;

public DAGClientAMProtocolBlockingPBServerImpl(DAGClientHandler real, FileSystem stagingFs) {
this.real = real;
this.stagingFs = stagingFs;
try{
amUGI = UserGroupInformation.getCurrentUser();
} catch (IOException e) {
//We do not throw exception because maybe this will not be used if there is no big request
LOG.error("Exception while getting current user", e);
}
}

private UserGroupInformation getRPCUser() throws ServiceException {
//should always be null exception reflect in unit test
if (rpcUGI != null) return rpcUGI;
try {
return UserGroupInformation.getCurrentUser();
} catch (IOException e) {
Expand Down Expand Up @@ -166,18 +182,24 @@ public SubmitDAGResponseProto submitDAG(RpcController controller,
real.updateLastHeartbeatTime();
try{
if (request.hasSerializedRequestPath()) {
// need to deserialize large request from hdfs
//Here we will check userGroupInformation to see if its null, should NEVER happened but in case happened
//we will use RPC user instead to do best effort try (May still fail but no other choice)
UserGroupInformation userToReadHDFS = amUGI == null? user : amUGI;
Path requestPath = new Path(request.getSerializedRequestPath());
FileSystem fs = requestPath.getFileSystem(stagingFs.getConf());
try (FSDataInputStream fsDataInputStream = fs.open(requestPath)) {
CodedInputStream in =
CodedInputStream.newInstance(fsDataInputStream);
in.setSizeLimit(Integer.MAX_VALUE);
request = SubmitDAGRequestProto.parseFrom(in);
} catch (IOException e) {
throw wrapException(e);
}
LOG.debug("Using the user {} to get the DAG plan from HDFS", userToReadHDFS);

request = userToReadHDFS.doAs((PrivilegedExceptionAction<SubmitDAGRequestProto>) () -> {
FileSystem fs = requestPath.getFileSystem(stagingFs.getConf());
try (FSDataInputStream fsDataInputStream = fs.open(requestPath)) {
CodedInputStream in = CodedInputStream.newInstance(fsDataInputStream);
in.setSizeLimit(Integer.MAX_VALUE);
return SubmitDAGRequestProto.parseFrom(in);
} catch (IOException e) {
throw wrapException(e);
}
});
}

DAGPlan dagPlan = request.getDAGPlan();
Map<String, LocalResource> additionalResources = null;
if (request.hasAdditionalAmResources()) {
Expand All @@ -188,6 +210,9 @@ public SubmitDAGResponseProto submitDAG(RpcController controller,
return SubmitDAGResponseProto.newBuilder().setDagId(dagId).build();
} catch(IOException | TezException e) {
throw wrapException(e);
} catch (InterruptedException e){
Thread.currentThread().interrupt();
throw wrapException(e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,71 @@ public void testSubmitDagInSessionWithLargeDagPlan() throws Exception {
assertEquals(lrURL.getPort(), port);
assertEquals(lrURL.getFile(), path);
}

@Test(timeout = 5000)
public void testSubmitDAGUserGroupInformation() throws Exception {
// Create a simple DAG plan and write it to a file
String dagPlanName = "test-dag";
File requestFile = tmpFolder.newFile("request-file");
TezConfiguration conf = new TezConfiguration();

DAGPlan dagPlan = DAG.create(dagPlanName)
.addVertex(Vertex.create("V", ProcessorDescriptor.create("P"), 1))
.createDag(conf, null, null, null, false);

// Write DAG plan to file
try (FileOutputStream fileOutputStream = new FileOutputStream(requestFile)) {
SubmitDAGRequestProto.newBuilder().setDAGPlan(dagPlan).build().writeTo(fileOutputStream);
}

// Setup mocks
DAGClientHandler dagClientHandler = mock(DAGClientHandler.class);
ACLManager aclManager = mock(ACLManager.class);
FileSystem mockFs = mock(FileSystem.class);
UserGroupInformation mockAmUgi = mock(UserGroupInformation.class);
UserGroupInformation mockRpcUgi = mock(UserGroupInformation.class);

// DAG request with file
SubmitDAGRequestProto request = SubmitDAGRequestProto.newBuilder()
.setSerializedRequestPath(requestFile.getAbsolutePath())
.build();

when(mockAmUgi.doAs(any(java.security.PrivilegedExceptionAction.class))).thenReturn(request);
when(mockRpcUgi.doAs(any(java.security.PrivilegedExceptionAction.class))).thenReturn(request);

// Create spy on server impl with mocked FileSystem
DAGClientAMProtocolBlockingPBServerImpl serverImpl = spy(new DAGClientAMProtocolBlockingPBServerImpl(
dagClientHandler, mockFs));

// Mock behavior
when(dagClientHandler.getACLManager()).thenReturn(aclManager);
when(aclManager.checkAMModifyAccess(any(UserGroupInformation.class))).thenReturn(true);
when(dagClientHandler.submitDAG(any(DAGPlan.class), any())).thenReturn("dag-id");
when(mockFs.getConf()).thenReturn(conf);

//Set the RPC UGI
java.lang.reflect.Field rpcUGIField = DAGClientAMProtocolBlockingPBServerImpl.class.getDeclaredField("rpcUGI");
rpcUGIField.setAccessible(true);
rpcUGIField.set(serverImpl, mockRpcUgi);

// Test Case 1: When amUGI is available
// Set the amUGI field using reflection
java.lang.reflect.Field amUGIField = DAGClientAMProtocolBlockingPBServerImpl.class.getDeclaredField("amUGI");
amUGIField.setAccessible(true);
amUGIField.set(serverImpl, mockAmUgi);

serverImpl.submitDAG(null, request);

// Verify amUGI was used for doAs
verify(mockAmUgi).doAs(any(java.security.PrivilegedExceptionAction.class));

// Test Case 2: When amUGI is null
// Set amUGI to null
amUGIField.set(serverImpl, null);

// Submit DAG with serialized path
serverImpl.submitDAG(null, request);
// Verify RPC user (mockRpcUgi) was used for doAs
verify(mockRpcUgi).doAs(any(java.security.PrivilegedExceptionAction.class));
}
}