Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,10 @@ class UnflattenColumnPlugin extends SparkCommandPlugin[UnflattenColumnArgs, Unit
val schema = frame.schema
val compositeKeyNames = arguments.columns
val compositeKeyIndices = compositeKeyNames.map(schema.columnIndex)

// run the operation
val targetSchema = UnflattenColumnFunctions.createTargetSchema(schema, compositeKeyNames)
val initialRdd = frame.rdd.groupByRows(row => row.values(compositeKeyNames))
val resultRdd = UnflattenColumnFunctions.unflattenRddByCompositeKey(compositeKeyIndices, initialRdd, targetSchema, arguments.delimiter.getOrElse(defaultDelimiter))

frame.save(new FrameRdd(targetSchema, resultRdd))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.trustedanalytics.atk.engine.frame.RowWrapper

/**
* Functions for joining pair RDDs using broadcast variables
Expand All @@ -30,36 +31,35 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali
* Perform left outer-join using a broadcast variable
*
* @param other join parameter for second data frame
*
* @return key-value RDD whose values are results of left-outer join
*/
def leftBroadcastJoin(other: RddJoinParam): RDD[Row] = {
val rightBroadcastVariable = JoinBroadcastVariable(other)
lazy val rightNullRow: Row = new GenericRow(other.frame.numColumns)

val leftJoinColumns = self.joinColumns.toList
self.frame.flatMapRows(left => {
val leftKey = left.value(self.joinColumn)
rightBroadcastVariable.get(leftKey) match {
val leftKeys = left.values(leftJoinColumns)
rightBroadcastVariable.get(leftKeys) match {
case Some(rightRowSet) => for (rightRow <- rightRowSet) yield Row.merge(left.row, rightRow)
case _ => List(Row.merge(left.row, rightNullRow.copy()))
}

})
}

/**
* Right outer-join using a broadcast variable
*
* @param other join parameter for second data frame
*
* @return key-value RDD whose values are results of right-outer join
*/
def rightBroadcastJoin(other: RddJoinParam): RDD[Row] = {
val leftBroadcastVariable = JoinBroadcastVariable(self)
lazy val leftNullRow: Row = new GenericRow(self.frame.numColumns)

val rightJoinColumns = other.joinColumns.toList
other.frame.flatMapRows(right => {
val rightKey = right.value(other.joinColumn)
leftBroadcastVariable.get(rightKey) match {
val rightKeys = right.values(rightJoinColumns)
leftBroadcastVariable.get(rightKeys) match {
case Some(leftRowSet) => for (leftRow <- leftRowSet) yield Row.merge(leftRow, right.row)
case _ => List(Row.merge(leftNullRow.copy(), right.row))
}
Expand All @@ -70,31 +70,36 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali
* Inner-join using a broadcast variable
*
* @param other join parameter for second data frame
*
* @return key-value RDD whose values are results of inner-outer join
*/
def innerBroadcastJoin(other: RddJoinParam, broadcastJoinThreshold: Long): RDD[Row] = {
val leftSizeInBytes = self.estimatedSizeInBytes.getOrElse(Long.MaxValue)
val rightSizeInBytes = other.estimatedSizeInBytes.getOrElse(Long.MaxValue)

val rowWrapper = new RowWrapper(other.frame.frameSchema)
val innerJoinedRDD = if (rightSizeInBytes <= broadcastJoinThreshold) {
val rightBroadcastVariable = JoinBroadcastVariable(other)

val rightColsToKeep = other.frame.frameSchema.dropColumns(other.joinColumns.toList).columnNames
val leftJoinColumns = self.joinColumns.toList
self.frame.flatMapRows(left => {
val leftKey = left.value(self.joinColumn)
rightBroadcastVariable.get(leftKey) match {
val leftKeys = left.values(leftJoinColumns)
rightBroadcastVariable.get(leftKeys) match {
case Some(rightRowSet) =>
for (rightRow <- rightRowSet) yield Row.merge(left.row, rightRow)
for (rightRow <- rightRowSet) yield Row.merge(left.row, new GenericRow(rowWrapper(rightRow).values(rightColsToKeep).toArray))
case _ => Set.empty[Row]
}
})
}
else if (leftSizeInBytes <= broadcastJoinThreshold) {
val leftBroadcastVariable = JoinBroadcastVariable(self)
val rightJoinColumns = other.joinColumns.toList
other.frame.flatMapRows(rightRow => {
val rightKey = rightRow.value(other.joinColumn)
leftBroadcastVariable.get(rightKey) match {
val leftColsToKeep = self.frame.frameSchema.dropColumns(self.joinColumns.toList).columnNames
val rightKeys = rightRow.values(rightJoinColumns)
leftBroadcastVariable.get(rightKeys) match {
case Some(leftRowSet) =>
for (leftRow <- leftRowSet) yield Row.merge(leftRow, rightRow.row)
for (leftRow <- leftRowSet) yield Row.merge(new GenericRow(rowWrapper(leftRow).values(leftColsToKeep).toArray), rightRow.row)
case _ => Set.empty[Row]
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ case class JoinArgs(leftFrame: JoinFrameArgs,
@ArgDoc("""The type of skewed join: 'skewedhash' or 'skewedbroadcast'""") skewedJoinType: Option[String] = None) {
require(leftFrame != null && leftFrame.frame != null, "left frame is required")
require(rightFrame != null && rightFrame.frame != null, "right frame is required")
require(leftFrame.joinColumn != null, "left join column is required")
require(rightFrame.joinColumn != null, "right join column is required")
require(leftFrame.joinColumns != null, "left join column is required")
require(rightFrame.joinColumns != null, "right join column is required")
require(how != null, "join method is required")
require(skewedJoinType.isEmpty
|| (skewedJoinType.get == "skewedhash" || skewedJoinType.get == "skewedbroadcast"),
Expand All @@ -43,6 +43,6 @@ case class JoinArgs(leftFrame: JoinFrameArgs,
* Join arguments for frame
*
* @param frame Data frame
* @param joinColumn Join column name
* @param joinColumns Join column name
*/
case class JoinFrameArgs(frame: FrameReference, joinColumn: String)
case class JoinFrameArgs(frame: FrameReference, joinColumns: List[String])
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ case class JoinBroadcastVariable(joinParam: RddJoinParam) {
// Create the broadcast variable for the join
private def createBroadcastMultiMaps(joinParam: RddJoinParam): Broadcast[MultiMap[Any, Row]] = {
//Grouping by key to ensure that duplicate keys are not split across different broadcast variables
val broadcastList = joinParam.frame.groupByRows(row => row.value(joinParam.joinColumn)).collect().toList

val broadcastList = joinParam.frame.groupByRows(row => row.values(joinParam.joinColumns.toList)).collect().toList
val broadcastMultiMap = listToMultiMap(broadcastList)
joinParam.frame.sparkContext.broadcast(broadcastMultiMap)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,23 @@ class JoinPlugin extends SparkCommandPlugin[JoinArgs, FrameReference] {
val rightFrame: SparkFrame = arguments.rightFrame.frame

//first validate join columns are valid
leftFrame.schema.validateColumnsExist(List(arguments.leftFrame.joinColumn))
rightFrame.schema.validateColumnsExist(List(arguments.rightFrame.joinColumn))
require(DataTypes.isCompatibleDataType(
leftFrame.schema.columnDataType(arguments.leftFrame.joinColumn),
rightFrame.schema.columnDataType(arguments.rightFrame.joinColumn)),
"Join columns must have compatible data types")
leftFrame.schema.validateColumnsExist(arguments.leftFrame.joinColumns)
rightFrame.schema.validateColumnsExist(arguments.rightFrame.joinColumns)

//Check left join column is compatiable with right join column
(arguments.leftFrame.joinColumns zip arguments.rightFrame.joinColumns).map {
case (leftJoinCol, rightJoinCol) => require(DataTypes.isCompatibleDataType(
leftFrame.schema.columnDataType(leftJoinCol),
rightFrame.schema.columnDataType(rightJoinCol)),
"Join columns must have compatible data types")
}

// Get estimated size of frame to determine whether to use a broadcast join
val broadcastJoinThreshold = EngineConfig.broadcastJoinThreshold

val joinedFrame = JoinRddFunctions.join(
createRDDJoinParam(leftFrame, arguments.leftFrame.joinColumn, broadcastJoinThreshold),
createRDDJoinParam(rightFrame, arguments.rightFrame.joinColumn, broadcastJoinThreshold),
createRDDJoinParam(leftFrame, arguments.leftFrame.joinColumns, broadcastJoinThreshold),
createRDDJoinParam(rightFrame, arguments.rightFrame.joinColumns, broadcastJoinThreshold),
arguments.how,
broadcastJoinThreshold,
arguments.skewedJoinType
Expand All @@ -95,14 +99,14 @@ class JoinPlugin extends SparkCommandPlugin[JoinArgs, FrameReference] {

//Create parameters for join
private def createRDDJoinParam(frame: SparkFrame,
joinColumn: String,
joinColumns: Seq[String],
broadcastJoinThreshold: Long): RddJoinParam = {
val frameSize = if (broadcastJoinThreshold > 0) frame.sizeInBytes else None
val estimatedRddSize = frameSize match {
case Some(size) => Some((size * EngineConfig.frameCompressionRatio).toLong)
case _ => None
}
RddJoinParam(frame.rdd, joinColumn, estimatedRddSize)
RddJoinParam(frame.rdd, joinColumns, estimatedRddSize)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ object JoinRddFunctions extends Serializable {
* @param left join parameter for first data frame
* @param right join parameter for second data frame
* @param broadcastJoinThreshold use broadcast variable for join if size of one of the data frames is below threshold
*
* @return Joined RDD
*/
def innerJoin(left: RddJoinParam,
Expand All @@ -88,12 +87,31 @@ object JoinRddFunctions extends Serializable {
val rightFrame = right.frame.toDataFrame
val joinedFrame = leftFrame.join(
rightFrame,
leftFrame(left.joinColumn).equalTo(rightFrame(right.joinColumn))
left.joinColumns
)
joinedFrame.rdd
}
}

/**
* expression maker helps for generating conditions to check when join invoked with composite keys
*
* @param leftFrame left data frame
* @param rightFrame rigth data frame
* @param leftJoinCols list of left frame column names used in join
* @param rightJoinCols list of right frame column name used in join
* @return
*/
def expressionMaker(leftFrame: DataFrame, rightFrame: DataFrame, leftJoinCols: Seq[String], rightJoinCols: Seq[String]): Column = {
val columnsTuple = leftJoinCols.zip(rightJoinCols)

def makeExpression(leftCol: String, rightCol: String): Column = {
leftFrame(leftCol).equalTo(rightFrame(rightCol))
}
val expression = columnsTuple.map { case (lc, rc) => makeExpression(lc, rc) }.reduce(_ && _)
expression
}

/**
* Perform full-outer join
*
Expand All @@ -102,14 +120,14 @@ object JoinRddFunctions extends Serializable {
*
* @param left join parameter for first data frame
* @param right join parameter for second data frame
*
* @return Joined RDD
*/
def fullOuterJoin(left: RddJoinParam, right: RddJoinParam): RDD[Row] = {
val leftFrame = left.frame.toDataFrame
val rightFrame = right.frame.toDataFrame
val expression = expressionMaker(leftFrame, rightFrame, left.joinColumns, right.joinColumns)
val joinedFrame = leftFrame.join(rightFrame,
leftFrame(left.joinColumn).equalTo(rightFrame(right.joinColumn)),
expression,
joinType = "fullouter"
)
joinedFrame.rdd
Expand All @@ -123,7 +141,6 @@ object JoinRddFunctions extends Serializable {
* @param left join parameter for first data frame
* @param right join parameter for second data frame
* @param broadcastJoinThreshold use broadcast variable for join if size of first data frame is below threshold
*
* @return Joined RDD
*/
def rightOuterJoin(left: RddJoinParam,
Expand All @@ -138,8 +155,9 @@ object JoinRddFunctions extends Serializable {
case _ =>
val leftFrame = left.frame.toDataFrame
val rightFrame = right.frame.toDataFrame
val expression = expressionMaker(leftFrame, rightFrame, left.joinColumns, right.joinColumns)
val joinedFrame = leftFrame.join(rightFrame,
leftFrame(left.joinColumn).equalTo(rightFrame(right.joinColumn)),
expression,
joinType = "right"
)
joinedFrame.rdd
Expand All @@ -154,7 +172,6 @@ object JoinRddFunctions extends Serializable {
* @param left join parameter for first data frame
* @param right join parameter for second data frame
* @param broadcastJoinThreshold use broadcast variable for join if size of second data frame is below threshold
*
* @return Joined RDD
*/
def leftOuterJoin(left: RddJoinParam,
Expand All @@ -167,10 +184,12 @@ object JoinRddFunctions extends Serializable {
case _ =>
val leftFrame = left.frame.toDataFrame
val rightFrame = right.frame.toDataFrame
val expression = expressionMaker(leftFrame, rightFrame, left.joinColumns, right.joinColumns)
val joinedFrame = leftFrame.join(rightFrame,
leftFrame(left.joinColumn).equalTo(rightFrame(right.joinColumn)),
expression,
joinType = "left"
)

joinedFrame.rdd
}
}
Expand All @@ -193,13 +212,13 @@ object JoinRddFunctions extends Serializable {
how match {
case "outer" => {
val mergedRdd = mergeJoinColumns(joinedRdd, left, right)
dropRightJoinColumn(mergedRdd, left, right)
dropRightJoinColumn(mergedRdd, left, right, how)
}
case "right" => {
dropLeftJoinColumn(joinedRdd, left, right)
}
case _ => {
dropRightJoinColumn(joinedRdd, left, right)
dropRightJoinColumn(joinedRdd, left, right, how)
}
}
}
Expand All @@ -217,19 +236,21 @@ object JoinRddFunctions extends Serializable {
def mergeJoinColumns(joinedRdd: RDD[Row],
left: RddJoinParam,
right: RddJoinParam): RDD[Row] = {

val leftSchema = left.frame.frameSchema
val rightSchema = right.frame.frameSchema

val leftJoinIndex = leftSchema.columnIndex(left.joinColumn)
val rightJoinIndex = rightSchema.columnIndex(right.joinColumn) + leftSchema.columns.size
val leftJoinIndices = leftSchema.columnIndices(left.joinColumns)
val rightJoinIndices = rightSchema.columnIndices(right.joinColumns).map(rightindex => rightindex + leftSchema.columns.size)

joinedRdd.map(row => {
val leftKey = row.get(leftJoinIndex)
val rightKey = row.get(rightJoinIndex)
val newLeftKey = if (leftKey == null) rightKey else leftKey

val rowArray = row.toSeq.toArray
rowArray(leftJoinIndex) = newLeftKey
leftJoinIndices.zip(rightJoinIndices).map {
case (leftIndex, rightIndex) => {
if (row.get(leftIndex) == null) {
rowArray(leftIndex) = row.get(rightIndex)
}
}
}
new GenericRow(rowArray)
})
}
Expand All @@ -249,16 +270,11 @@ object JoinRddFunctions extends Serializable {
right: RddJoinParam): FrameRdd = {
val leftSchema = left.frame.frameSchema
val rightSchema = right.frame.frameSchema

// Get unique name for left join column to drop
val oldSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns))
val dropColumnName = oldSchema.column(leftSchema.columnIndex(left.joinColumn)).name

// Create new schema
val newLeftSchema = leftSchema.renameColumn(left.joinColumn, dropColumnName)
val newSchema = FrameSchema(Schema.join(newLeftSchema.columns, rightSchema.columns))

new FrameRdd(newSchema, joinedRdd)
val newSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns))
val frameRdd = new FrameRdd(newSchema, joinedRdd)
val leftColIndices = leftSchema.columnIndices(left.joinColumns)
val leftColNames = leftColIndices.map(colindex => newSchema.column(colindex).name)
frameRdd.dropColumns(leftColNames.toList)
}

/**
Expand All @@ -273,18 +289,24 @@ object JoinRddFunctions extends Serializable {
*/
def dropRightJoinColumn(joinedRdd: RDD[Row],
left: RddJoinParam,
right: RddJoinParam): FrameRdd = {
right: RddJoinParam,
how: String): FrameRdd = {

val leftSchema = left.frame.frameSchema
val rightSchema = right.frame.frameSchema

// Get unique name for right join column to drop
val oldSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns))
val dropColumnName = oldSchema.column(leftSchema.columns.size + rightSchema.columnIndex(right.joinColumn)).name

// Create new schema
val newRightSchema = rightSchema.renameColumn(right.joinColumn, dropColumnName)
val newSchema = FrameSchema(Schema.join(leftSchema.columns, newRightSchema.columns))

new FrameRdd(newSchema, joinedRdd)
if (how == "inner") {
val newRightSchema = rightSchema.dropColumns(right.joinColumns.toList)
val newSchema = FrameSchema(Schema.join(leftSchema.columns, newRightSchema.columns))
new FrameRdd(newSchema, joinedRdd)
}
else {
val newSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns))
val frameRdd = new FrameRdd(newSchema, joinedRdd)
val rightColIndices = rightSchema.columnIndices(right.joinColumns).map(rightindex => leftSchema.columns.size + rightindex)
val rightColNames = rightColIndices.map(colindex => newSchema.column(colindex).name)
frameRdd.dropColumns(rightColNames.toList)
}
}
}
Loading