diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/UnflattenColumnsPlugin.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/UnflattenColumnsPlugin.scala index 0cff524f7f..d3ce5c06d1 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/UnflattenColumnsPlugin.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/UnflattenColumnsPlugin.scala @@ -65,12 +65,10 @@ class UnflattenColumnPlugin extends SparkCommandPlugin[UnflattenColumnArgs, Unit val schema = frame.schema val compositeKeyNames = arguments.columns val compositeKeyIndices = compositeKeyNames.map(schema.columnIndex) - // run the operation val targetSchema = UnflattenColumnFunctions.createTargetSchema(schema, compositeKeyNames) val initialRdd = frame.rdd.groupByRows(row => row.values(compositeKeyNames)) val resultRdd = UnflattenColumnFunctions.unflattenRddByCompositeKey(compositeKeyIndices, initialRdd, targetSchema, arguments.delimiter.getOrElse(defaultDelimiter)) - frame.save(new FrameRdd(targetSchema, resultRdd)) } diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala index 6982684c61..16b87baf96 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala @@ -20,6 +20,7 @@ import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.trustedanalytics.atk.engine.frame.RowWrapper /** * Functions for joining pair RDDs using broadcast variables @@ -30,19 +31,19 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali * Perform left outer-join using a broadcast variable * * @param other join parameter for second data frame - * * @return key-value RDD whose values are results of left-outer join */ def leftBroadcastJoin(other: RddJoinParam): RDD[Row] = { val rightBroadcastVariable = JoinBroadcastVariable(other) lazy val rightNullRow: Row = new GenericRow(other.frame.numColumns) - + val leftJoinColumns = self.joinColumns.toList self.frame.flatMapRows(left => { - val leftKey = left.value(self.joinColumn) - rightBroadcastVariable.get(leftKey) match { + val leftKeys = left.values(leftJoinColumns) + rightBroadcastVariable.get(leftKeys) match { case Some(rightRowSet) => for (rightRow <- rightRowSet) yield Row.merge(left.row, rightRow) case _ => List(Row.merge(left.row, rightNullRow.copy())) } + }) } @@ -50,16 +51,15 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali * Right outer-join using a broadcast variable * * @param other join parameter for second data frame - * * @return key-value RDD whose values are results of right-outer join */ def rightBroadcastJoin(other: RddJoinParam): RDD[Row] = { val leftBroadcastVariable = JoinBroadcastVariable(self) lazy val leftNullRow: Row = new GenericRow(self.frame.numColumns) - + val rightJoinColumns = other.joinColumns.toList other.frame.flatMapRows(right => { - val rightKey = right.value(other.joinColumn) - 
leftBroadcastVariable.get(rightKey) match { + val rightKeys = right.values(rightJoinColumns) + leftBroadcastVariable.get(rightKeys) match { case Some(leftRowSet) => for (leftRow <- leftRowSet) yield Row.merge(leftRow, right.row) case _ => List(Row.merge(leftNullRow.copy(), right.row)) } @@ -70,31 +70,36 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali * Inner-join using a broadcast variable * * @param other join parameter for second data frame - * * @return key-value RDD whose values are results of inner join */ def innerBroadcastJoin(other: RddJoinParam, broadcastJoinThreshold: Long): RDD[Row] = { val leftSizeInBytes = self.estimatedSizeInBytes.getOrElse(Long.MaxValue) val rightSizeInBytes = other.estimatedSizeInBytes.getOrElse(Long.MaxValue) + val leftRowWrapper = new RowWrapper(self.frame.frameSchema) + val rightRowWrapper = new RowWrapper(other.frame.frameSchema) val innerJoinedRDD = if (rightSizeInBytes <= broadcastJoinThreshold) { val rightBroadcastVariable = JoinBroadcastVariable(other) + + val rightColsToKeep = other.frame.frameSchema.dropColumns(other.joinColumns.toList).columnNames + val leftJoinColumns = self.joinColumns.toList self.frame.flatMapRows(left => { - val leftKey = left.value(self.joinColumn) - rightBroadcastVariable.get(leftKey) match { + val leftKeys = left.values(leftJoinColumns) + rightBroadcastVariable.get(leftKeys) match { case Some(rightRowSet) => - for (rightRow <- rightRowSet) yield Row.merge(left.row, rightRow) + for (rightRow <- rightRowSet) yield Row.merge(left.row, new GenericRow(rightRowWrapper(rightRow).values(rightColsToKeep).toArray)) case _ => Set.empty[Row] } }) } else if (leftSizeInBytes <= broadcastJoinThreshold) { val leftBroadcastVariable = JoinBroadcastVariable(self) + val leftColsToKeep = self.frame.frameSchema.dropColumns(self.joinColumns.toList).columnNames + val rightJoinColumns = other.joinColumns.toList other.frame.flatMapRows(rightRow => { - val rightKey = rightRow.value(other.joinColumn) - leftBroadcastVariable.get(rightKey) match { + val rightKeys = rightRow.values(rightJoinColumns) + leftBroadcastVariable.get(rightKeys) match { case Some(leftRowSet) => - for (leftRow <- leftRowSet) yield Row.merge(leftRow, rightRow.row) + for (leftRow <- leftRowSet) yield Row.merge(new GenericRow(leftRowWrapper(leftRow).values(leftColsToKeep).toArray), rightRow.row) case _ => Set.empty[Row] } }) diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinArgs.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinArgs.scala index b903fd7727..81ec4a5ac4 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinArgs.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinArgs.scala @@ -31,8 +31,8 @@ case class JoinArgs(leftFrame: JoinFrameArgs, @ArgDoc("""The type of skewed join: 'skewedhash' or 'skewedbroadcast'""") skewedJoinType: Option[String] = None) { require(leftFrame != null && leftFrame.frame != null, "left frame is required") require(rightFrame != null && rightFrame.frame != null, "right frame is required") - require(leftFrame.joinColumn != null, "left join column is required") - require(rightFrame.joinColumn != null, "right join column is required") + require(leftFrame.joinColumns != null, "left join columns are required") + require(rightFrame.joinColumns != null, "right join columns are required") require(how != null, "join method is required")
require(skewedJoinType.isEmpty || (skewedJoinType.get == "skewedhash" || skewedJoinType.get == "skewedbroadcast"), @@ -43,6 +43,6 @@ case class JoinArgs(leftFrame: JoinFrameArgs, * Join arguments for frame * * @param frame Data frame - * @param joinColumn Join column name + * @param joinColumns Join column names */ -case class JoinFrameArgs(frame: FrameReference, joinColumn: String) +case class JoinFrameArgs(frame: FrameReference, joinColumns: List[String]) diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariable.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariable.scala index b88606c6a8..3d6a445eff 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariable.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariable.scala @@ -52,8 +52,7 @@ case class JoinBroadcastVariable(joinParam: RddJoinParam) { // Create the broadcast variable for the join private def createBroadcastMultiMaps(joinParam: RddJoinParam): Broadcast[MultiMap[Any, Row]] = { //Grouping by key to ensure that duplicate keys are not split across different broadcast variables - val broadcastList = joinParam.frame.groupByRows(row => row.value(joinParam.joinColumn)).collect().toList - + val broadcastList = joinParam.frame.groupByRows(row => row.values(joinParam.joinColumns.toList)).collect().toList val broadcastMultiMap = listToMultiMap(broadcastList) joinParam.frame.sparkContext.broadcast(broadcastMultiMap) } diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinPlugin.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinPlugin.scala index 8feb573fec..df24e2d02a 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinPlugin.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinPlugin.scala @@ -69,19 +69,23 @@ class JoinPlugin extends SparkCommandPlugin[JoinArgs, FrameReference] { val rightFrame: SparkFrame = arguments.rightFrame.frame //first validate join columns are valid - leftFrame.schema.validateColumnsExist(List(arguments.leftFrame.joinColumn)) - rightFrame.schema.validateColumnsExist(List(arguments.rightFrame.joinColumn)) - require(DataTypes.isCompatibleDataType( - leftFrame.schema.columnDataType(arguments.leftFrame.joinColumn), - rightFrame.schema.columnDataType(arguments.rightFrame.joinColumn)), - "Join columns must have compatible data types") + leftFrame.schema.validateColumnsExist(arguments.leftFrame.joinColumns) + rightFrame.schema.validateColumnsExist(arguments.rightFrame.joinColumns) + + //Check that each left join column is compatible with the corresponding right join column + (arguments.leftFrame.joinColumns zip arguments.rightFrame.joinColumns).foreach { + case (leftJoinCol, rightJoinCol) => require(DataTypes.isCompatibleDataType( + leftFrame.schema.columnDataType(leftJoinCol), + rightFrame.schema.columnDataType(rightJoinCol)), + "Join columns must have compatible data types") + } // Get estimated size of frame to determine whether to use a broadcast join val broadcastJoinThreshold = EngineConfig.broadcastJoinThreshold val joinedFrame = JoinRddFunctions.join( - createRDDJoinParam(leftFrame, arguments.leftFrame.joinColumn, 
broadcastJoinThreshold), - createRDDJoinParam(rightFrame, arguments.rightFrame.joinColumn, broadcastJoinThreshold), + createRDDJoinParam(leftFrame, arguments.leftFrame.joinColumns, broadcastJoinThreshold), + createRDDJoinParam(rightFrame, arguments.rightFrame.joinColumns, broadcastJoinThreshold), arguments.how, broadcastJoinThreshold, arguments.skewedJoinType @@ -95,14 +99,14 @@ class JoinPlugin extends SparkCommandPlugin[JoinArgs, FrameReference] { //Create parameters for join private def createRDDJoinParam(frame: SparkFrame, - joinColumn: String, + joinColumns: Seq[String], broadcastJoinThreshold: Long): RddJoinParam = { val frameSize = if (broadcastJoinThreshold > 0) frame.sizeInBytes else None val estimatedRddSize = frameSize match { case Some(size) => Some((size * EngineConfig.frameCompressionRatio).toLong) case _ => None } - RddJoinParam(frame.rdd, joinColumn, estimatedRddSize) + RddJoinParam(frame.rdd, joinColumns, estimatedRddSize) } } diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala index 5d980b83dd..6ef485bd2d 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala @@ -69,7 +69,6 @@ object JoinRddFunctions extends Serializable { * @param left join parameter for first data frame * @param right join parameter for second data frame * @param broadcastJoinThreshold use broadcast variable for join if size of one of the data frames is below threshold - * * @return Joined RDD */ def innerJoin(left: RddJoinParam, @@ -88,12 +87,31 @@ object JoinRddFunctions extends Serializable { val rightFrame = right.frame.toDataFrame val joinedFrame = leftFrame.join( rightFrame, - leftFrame(left.joinColumn).equalTo(rightFrame(right.joinColumn)) + left.joinColumns ) joinedFrame.rdd } } + /** + * Generates the join condition used when a join is invoked with composite keys + * + * @param leftFrame left data frame + * @param rightFrame right data frame + * @param leftJoinCols list of left frame column names used in join + * @param rightJoinCols list of right frame column names used in join + * @return Column expression that ANDs the equality check of each corresponding pair of join columns + */ + def expressionMaker(leftFrame: DataFrame, rightFrame: DataFrame, leftJoinCols: Seq[String], rightJoinCols: Seq[String]): Column = { + val columnsTuple = leftJoinCols.zip(rightJoinCols) + + def makeExpression(leftCol: String, rightCol: String): Column = { + leftFrame(leftCol).equalTo(rightFrame(rightCol)) + } + columnsTuple.map { case (lc, rc) => makeExpression(lc, rc) }.reduce(_ && _) + } + /** * Perform full-outer join * * @@ -102,14 +120,14 @@ * @param left join parameter for first data frame * @param right join parameter for second data frame - * * @return Joined RDD */ def fullOuterJoin(left: RddJoinParam, right: RddJoinParam): RDD[Row] = { val leftFrame = left.frame.toDataFrame val rightFrame = right.frame.toDataFrame + val expression = expressionMaker(leftFrame, rightFrame, left.joinColumns, right.joinColumns) val joinedFrame = leftFrame.join(rightFrame, - leftFrame(left.joinColumn).equalTo(rightFrame(right.joinColumn)), + expression, joinType = "fullouter" ) joinedFrame.rdd @@ -123,7 +141,6 @@ object 
JoinRddFunctions extends Serializable { * @param left join parameter for first data frame * @param right join parameter for second data frame * @param broadcastJoinThreshold use broadcast variable for join if size of first data frame is below threshold - * * @return Joined RDD */ def rightOuterJoin(left: RddJoinParam, @@ -138,8 +155,9 @@ object JoinRddFunctions extends Serializable { case _ => val leftFrame = left.frame.toDataFrame val rightFrame = right.frame.toDataFrame + val expression = expressionMaker(leftFrame, rightFrame, left.joinColumns, right.joinColumns) val joinedFrame = leftFrame.join(rightFrame, - leftFrame(left.joinColumn).equalTo(rightFrame(right.joinColumn)), + expression, joinType = "right" ) joinedFrame.rdd @@ -154,7 +172,6 @@ object JoinRddFunctions extends Serializable { * @param left join parameter for first data frame * @param right join parameter for second data frame * @param broadcastJoinThreshold use broadcast variable for join if size of second data frame is below threshold - * * @return Joined RDD */ def leftOuterJoin(left: RddJoinParam, @@ -167,10 +184,12 @@ object JoinRddFunctions extends Serializable { case _ => val leftFrame = left.frame.toDataFrame val rightFrame = right.frame.toDataFrame + val expression = expressionMaker(leftFrame, rightFrame, left.joinColumns, right.joinColumns) val joinedFrame = leftFrame.join(rightFrame, - leftFrame(left.joinColumn).equalTo(rightFrame(right.joinColumn)), + expression, joinType = "left" ) + joinedFrame.rdd } } @@ -193,13 +212,13 @@ object JoinRddFunctions extends Serializable { how match { case "outer" => { val mergedRdd = mergeJoinColumns(joinedRdd, left, right) - dropRightJoinColumn(mergedRdd, left, right) + dropRightJoinColumn(mergedRdd, left, right, how) } case "right" => { dropLeftJoinColumn(joinedRdd, left, right) } case _ => { - dropRightJoinColumn(joinedRdd, left, right) + dropRightJoinColumn(joinedRdd, left, right, how) } } } @@ -217,19 +236,21 @@ object JoinRddFunctions extends Serializable { def mergeJoinColumns(joinedRdd: RDD[Row], left: RddJoinParam, right: RddJoinParam): RDD[Row] = { + val leftSchema = left.frame.frameSchema val rightSchema = right.frame.frameSchema - - val leftJoinIndex = leftSchema.columnIndex(left.joinColumn) - val rightJoinIndex = rightSchema.columnIndex(right.joinColumn) + val leftJoinIndices = leftSchema.columnIndices(left.joinColumns) + val rightJoinIndices = rightSchema.columnIndices(right.joinColumns).map(rightIndex => rightIndex + leftSchema.columns.size) joinedRdd.map(row => { - val leftKey = row.get(leftJoinIndex) - val rightKey = row.get(rightJoinIndex) - val newLeftKey = if (leftKey == null) rightKey else leftKey val rowArray = row.toSeq.toArray - rowArray(leftJoinIndex) = newLeftKey + leftJoinIndices.zip(rightJoinIndices).foreach { + case (leftIndex, rightIndex) => + if (row.get(leftIndex) == null) { + rowArray(leftIndex) = row.get(rightIndex) + } + } new GenericRow(rowArray) }) } @@ -249,16 +270,11 @@ object JoinRddFunctions extends Serializable { right: RddJoinParam): FrameRdd = { val leftSchema = left.frame.frameSchema val rightSchema = right.frame.frameSchema - - // Get unique name for left join column to drop - val oldSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns)) - val dropColumnName = oldSchema.column(leftSchema.columnIndex(left.joinColumn)).name - - // Create new schema - val newLeftSchema = leftSchema.renameColumn(left.joinColumn, dropColumnName) - val newSchema = 
FrameSchema(Schema.join(newLeftSchema.columns, rightSchema.columns)) - - new FrameRdd(newSchema, joinedRdd) + val newSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns)) + val frameRdd = new FrameRdd(newSchema, joinedRdd) + val leftColIndices = leftSchema.columnIndices(left.joinColumns) + val leftColNames = leftColIndices.map(colIndex => newSchema.column(colIndex).name) + frameRdd.dropColumns(leftColNames.toList) } /** @@ -273,18 +289,24 @@ object JoinRddFunctions extends Serializable { */ def dropRightJoinColumn(joinedRdd: RDD[Row], left: RddJoinParam, - right: RddJoinParam): FrameRdd = { + right: RddJoinParam, + how: String): FrameRdd = { + val leftSchema = left.frame.frameSchema val rightSchema = right.frame.frameSchema - // Get unique name for right join column to drop - val oldSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns)) - val dropColumnName = oldSchema.column(leftSchema.columns.size + rightSchema.columnIndex(right.joinColumn)).name - // Create new schema - val newRightSchema = rightSchema.renameColumn(right.joinColumn, dropColumnName) - val newSchema = FrameSchema(Schema.join(leftSchema.columns, newRightSchema.columns)) - - new FrameRdd(newSchema, joinedRdd) + if (how == "inner") { + val newRightSchema = rightSchema.dropColumns(right.joinColumns.toList) + val newSchema = FrameSchema(Schema.join(leftSchema.columns, newRightSchema.columns)) + new FrameRdd(newSchema, joinedRdd) + } + else { + val newSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns)) + val frameRdd = new FrameRdd(newSchema, joinedRdd) + val rightColIndices = rightSchema.columnIndices(right.joinColumns).map(rightIndex => leftSchema.columns.size + rightIndex) + val rightColNames = rightColIndices.map(colIndex => newSchema.column(colIndex).name) + frameRdd.dropColumns(rightColNames.toList) + } } } diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/RddJoinParam.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/RddJoinParam.scala index 325a522b40..20673e41d4 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/RddJoinParam.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/RddJoinParam.scala @@ -22,14 +22,14 @@ import org.apache.spark.frame.FrameRdd * Join parameters for RDD * * @param frame Frame used for join - * @param joinColumn Join column name + * @param joinColumns Join column names * @param estimatedSizeInBytes Optional estimated size of RDD in bytes used to determine whether to use a broadcast join */ case class RddJoinParam(frame: FrameRdd, - joinColumn: String, + joinColumns: Seq[String], estimatedSizeInBytes: Option[Long] = None) { require(frame != null, "join frame is required") - require(joinColumn != null, "join column is required") + require(joinColumns != null, "join column(s) are required") require(estimatedSizeInBytes.isEmpty || estimatedSizeInBytes.get > 0, "Estimated rdd size in bytes should be empty or greater than zero") } diff --git a/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariableITest.scala b/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariableITest.scala index 09ac8614d8..b65f3788e9 100644 --- 
a/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariableITest.scala +++ b/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariableITest.scala @@ -41,38 +41,39 @@ class JoinBroadcastVariableITest extends TestingSparkContextFlatSpec with Matche "JoinBroadcastVariable" should "create a single broadcast variable when RDD size is less than 2GB" in { val countryNames = new FrameRdd(inputSchema, sparkContext.parallelize(idCountryNames)) - val joinParam = RddJoinParam(countryNames, "col_0", Some(150)) + val joinParam = RddJoinParam(countryNames, Seq("col_0"), Some(150)) val broadcastVariable = JoinBroadcastVariable(joinParam) broadcastVariable.broadcastMultiMap.value.size should equal(5) - broadcastVariable.get(1).get should contain theSameElementsAs Set(idCountryNames(0), idCountryNames(1)) - broadcastVariable.get(2).get should contain theSameElementsAs Set(idCountryNames(2)) - broadcastVariable.get(3).get should contain theSameElementsAs Set(idCountryNames(3)) - broadcastVariable.get(4).get should contain theSameElementsAs Set(idCountryNames(4)) - broadcastVariable.get(6).get should contain theSameElementsAs Set(idCountryNames(5)) - broadcastVariable.get(8).isDefined should equal(false) + broadcastVariable.get(List(1)).get should contain theSameElementsAs Set(idCountryNames(0), idCountryNames(1)) + broadcastVariable.get(List(2)).get should contain theSameElementsAs Set(idCountryNames(2)) + broadcastVariable.get(List(3)).get should contain theSameElementsAs Set(idCountryNames(3)) + broadcastVariable.get(List(4)).get should contain theSameElementsAs Set(idCountryNames(4)) + broadcastVariable.get(List(6)).get should contain theSameElementsAs Set(idCountryNames(5)) + broadcastVariable.get(List(8)).isDefined should equal(false) } + "JoinBroadcastVariable" should "create two broadcast variables when RDD size equals 3GB" in { val countryNames = new FrameRdd(inputSchema, sparkContext.parallelize(idCountryNames)) - val joinParam = RddJoinParam(countryNames, "col_0", Some(3L * 1024 * 1024 * 1024)) + val joinParam = RddJoinParam(countryNames, Seq("col_0"), Some(3L * 1024 * 1024 * 1024)) val broadcastVariable = JoinBroadcastVariable(joinParam) - broadcastVariable.get(1).get should contain theSameElementsAs Set(idCountryNames(0), idCountryNames(1)) - broadcastVariable.get(2).get should contain theSameElementsAs Set(idCountryNames(2)) - broadcastVariable.get(3).get should contain theSameElementsAs Set(idCountryNames(3)) - broadcastVariable.get(4).get should contain theSameElementsAs Set(idCountryNames(4)) - broadcastVariable.get(6).get should contain theSameElementsAs Set(idCountryNames(5)) - broadcastVariable.get(8).isDefined should equal(false) + broadcastVariable.get(List(1)).get should contain theSameElementsAs Set(idCountryNames(0), idCountryNames(1)) + broadcastVariable.get(List(2)).get should contain theSameElementsAs Set(idCountryNames(2)) + broadcastVariable.get(List(3)).get should contain theSameElementsAs Set(idCountryNames(3)) + broadcastVariable.get(List(4)).get should contain theSameElementsAs Set(idCountryNames(4)) + broadcastVariable.get(List(6)).get should contain theSameElementsAs Set(idCountryNames(5)) + broadcastVariable.get(List(8)).isDefined should equal(false) } "JoinBroadcastVariable" should "create an empty broadcast variable" in { val countryNames = new FrameRdd(inputSchema, sparkContext.parallelize(List.empty[Row])) - val joinParam = 
RddJoinParam(countryNames, "col_0", Some(3L * 1024 * 1024 * 1024)) + val joinParam = RddJoinParam(countryNames, Seq("col_0"), Some(3L * 1024 * 1024 * 1024)) val broadcastVariable = JoinBroadcastVariable(joinParam) @@ -86,7 +87,7 @@ class JoinBroadcastVariableITest extends TestingSparkContextFlatSpec with Matche "JoinBroadcastVariable" should "throw an Exception if column does not exist in frame" in { intercept[Exception] { val countryNames = new FrameRdd(inputSchema, sparkContext.parallelize(idCountryNames)) - val joinParam = RddJoinParam(countryNames, "col_bad", Some(3L * 1024 * 1024 * 1024)) + val joinParam = RddJoinParam(countryNames, Seq("col_bad"), Some(3L * 1024 * 1024 * 1024)) JoinBroadcastVariable(joinParam) } } diff --git a/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/SparkJoinITest.scala b/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/SparkJoinITest.scala index cdcd78fe97..f0c285cf15 100644 --- a/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/SparkJoinITest.scala +++ b/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/SparkJoinITest.scala @@ -26,30 +26,32 @@ import org.trustedanalytics.atk.testutils.TestingSparkContextFlatSpec class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { // Test data has duplicate keys, matching and non-matching keys val idCountryCodes: List[Row] = List( - new GenericRow(Array[Any](1, 354)), - new GenericRow(Array[Any](2, 91)), - new GenericRow(Array[Any](2, 100)), - new GenericRow(Array[Any](3, 47)), - new GenericRow(Array[Any](4, 968)), - new GenericRow(Array[Any](5, 50))) + new GenericRow(Array[Any](1, 354, "a")), + new GenericRow(Array[Any](2, 91, "a")), + new GenericRow(Array[Any](2, 100, "b")), + new GenericRow(Array[Any](3, 47, "a")), + new GenericRow(Array[Any](4, 968, "c")), + new GenericRow(Array[Any](5, 50, "c"))) val idCountryNames: List[Row] = List( - new GenericRow(Array[Any](1, "Iceland")), - new GenericRow(Array[Any](1, "Ice-land")), - new GenericRow(Array[Any](2, "India")), - new GenericRow(Array[Any](3, "Norway")), - new GenericRow(Array[Any](4, "Oman")), - new GenericRow(Array[Any](6, "Germany")) + new GenericRow(Array[Any](1, "Iceland", "a")), + new GenericRow(Array[Any](1, "Ice-land", "a")), + new GenericRow(Array[Any](2, "India", "b")), + new GenericRow(Array[Any](3, "Norway", "a")), + new GenericRow(Array[Any](4, "Oman", "c")), + new GenericRow(Array[Any](6, "Germany", "c")) ) val codeSchema = FrameSchema(List( Column("col_0", DataTypes.int32), - Column("col_1", DataTypes.int32) + Column("col_1", DataTypes.int32), + Column("col_2", DataTypes.str) )) val countrySchema = FrameSchema(List( Column("col_0", DataTypes.int32), - Column("col_1", DataTypes.str) + Column("col_1", DataTypes.str), + Column("col_2", DataTypes.str) )) "joinRDDs" should "join two RDD with inner join" in { @@ -57,24 +59,25 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) - val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, "col_0"), - RddJoinParam(countryNames, "col_0"), "inner") + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0")), + RddJoinParam(countryNames, Seq("col_0")), "inner") val results = 
resultFrame.collect() resultFrame.frameSchema.columns should equal(List( Column("col_0", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), - Column("col_1_R", DataTypes.str, 3) + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, "Oman")) + new GenericRow(Array[Any](1, 354, "a", "Iceland", "a")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land", "a")), + new GenericRow(Array[Any](2, 91, "a", "India", "b")), + new GenericRow(Array[Any](2, 100, "b", "India", "b")), + new GenericRow(Array[Any](3, 47, "a", "Norway", "a")), + new GenericRow(Array[Any](4, 968, "c", "Oman", "c")) ) results should contain theSameElementsAs expectedResults @@ -84,8 +87,8 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) - val leftJoinParam = RddJoinParam(countryCode, "col_0", Some(150)) - val rightJoinParam = RddJoinParam(countryNames, "col_0", Some(10000)) + val leftJoinParam = RddJoinParam(countryCode, Seq("col_0"), Some(150)) + val rightJoinParam = RddJoinParam(countryNames, Seq("col_0"), Some(10000)) val resultFrame = JoinRddFunctions.join(leftJoinParam, rightJoinParam, "inner") val results = resultFrame.collect() @@ -93,45 +96,102 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { resultFrame.frameSchema.columns should equal(List( Column("col_0", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) + )) + + val expectedResults = List( + new GenericRow(Array[Any](1, 354, "a", "Iceland", "a")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land", "a")), + new GenericRow(Array[Any](2, 91, "a", "India", "b")), + new GenericRow(Array[Any](2, 100, "b", "India", "b")), + new GenericRow(Array[Any](3, 47, "a", "Norway", "a")), + new GenericRow(Array[Any](4, 968, "c", "Oman", "c")) + ) + + results should contain theSameElementsAs expectedResults + } + + "compositeJoinRDDs" should "join two RDD with inner join" in { + + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0", "col_2")), + RddJoinParam(countryNames, Seq("col_0", "col_2")), "inner") + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_0", DataTypes.int32, 0), + Column("col_1_L", DataTypes.int32, 1), + Column("col_2", DataTypes.str, 2), Column("col_1_R", DataTypes.str, 3) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, 
"Oman")) + new GenericRow(Array[Any](1, 354, "a", "Iceland")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land")), + new GenericRow(Array[Any](2, 100, "b", "India")), + new GenericRow(Array[Any](3, 47, "a", "Norway")), + new GenericRow(Array[Any](4, 968, "c", "Oman")) ) results should contain theSameElementsAs expectedResults } - "joinRDDs" should "join two RDD with left join" in { + "compositeJoinRDDs" should "join two RDD with inner join using broadcast variable" in { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) - val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, "col_0"), - RddJoinParam(countryNames, "col_0"), "left") + val leftJoinParam = RddJoinParam(countryCode, Seq("col_0", "col_2"), Some(150)) + val rightJoinParam = RddJoinParam(countryNames, Seq("col_0", "col_2"), Some(10000)) + + val resultFrame = JoinRddFunctions.join(leftJoinParam, rightJoinParam, "inner") val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( Column("col_0", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), + Column("col_2", DataTypes.str, 2), Column("col_1_R", DataTypes.str, 3) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, "Oman")), - new GenericRow(Array[Any](5, 50, null, null)) + new GenericRow(Array[Any](1, 354, "a", "Iceland")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land")), + new GenericRow(Array[Any](2, 100, "b", "India")), + new GenericRow(Array[Any](3, 47, "a", "Norway")), + new GenericRow(Array[Any](4, 968, "c", "Oman")) + ) + + results should contain theSameElementsAs expectedResults + } + + "joinRDDs" should "join two RDD with left join" in { + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0")), + RddJoinParam(countryNames, Seq("col_0")), "left") + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_0_L", DataTypes.int32, 0), + Column("col_1_L", DataTypes.int32, 1), + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) + )) + + val expectedResults = List( + new GenericRow(Array[Any](1, 354, "a", "Iceland", "a")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land", "a")), + new GenericRow(Array[Any](2, 91, "a", "India", "b")), + new GenericRow(Array[Any](2, 100, "b", "India", "b")), + new GenericRow(Array[Any](3, 47, "a", "Norway", "a")), + new GenericRow(Array[Any](4, 968, "c", "Oman", "c")), + new GenericRow(Array[Any](5, 50, "c", null, null)) ) results should contain theSameElementsAs expectedResults @@ -141,56 +201,118 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) - val leftJoinParam = RddJoinParam(countryCode, "col_0", Some(1500L)) - val rightJoinParam = 
RddJoinParam(countryNames, "col_0", Some(100L + Int.MaxValue)) + val leftJoinParam = RddJoinParam(countryCode, Seq("col_0"), Some(1500L)) + val rightJoinParam = RddJoinParam(countryNames, Seq("col_0"), Some(100L + Int.MaxValue)) // Test join wrapper function val resultFrame = JoinRddFunctions.join(leftJoinParam, rightJoinParam, "left") val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( - Column("col_0", DataTypes.int32, 0), + Column("col_0_L", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) + )) + + val expectedResults = List( + new GenericRow(Array[Any](1, 354, "a", "Iceland", "a")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land", "a")), + new GenericRow(Array[Any](2, 91, "a", "India", "b")), + new GenericRow(Array[Any](2, 100, "b", "India", "b")), + new GenericRow(Array[Any](3, 47, "a", "Norway", "a")), + new GenericRow(Array[Any](4, 968, "c", "Oman", "c")), + new GenericRow(Array[Any](5, 50, "c", null, null)) + ) + + results should contain theSameElementsAs expectedResults + } + + "compositeJoinRDDs" should "join two RDD with left join" in { + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0", "col_2")), + RddJoinParam(countryNames, Seq("col_0", "col_2")), "left") + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_0_L", DataTypes.int32, 0), + Column("col_1_L", DataTypes.int32, 1), + Column("col_2_L", DataTypes.str, 2), Column("col_1_R", DataTypes.str, 3) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, "Oman")), - new GenericRow(Array[Any](5, 50, null, null)) + new GenericRow(Array[Any](1, 354, "a", "Iceland")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land")), + new GenericRow(Array[Any](2, 91, "a", null)), + new GenericRow(Array[Any](2, 100, "b", "India")), + new GenericRow(Array[Any](3, 47, "a", "Norway")), + new GenericRow(Array[Any](4, 968, "c", "Oman")), + new GenericRow(Array[Any](5, 50, "c", null)) ) results should contain theSameElementsAs expectedResults } - "joinRDDs" should "join two RDD with right join" in { + + "compositeJoinRDDs" should "join two RDD with left join using broadcast variable" in { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) - val resultFrame = JoinRddFunctions.join( - RddJoinParam(countryCode, "col_0"), - RddJoinParam(countryNames, "col_0"), "right") + val leftJoinParam = RddJoinParam(countryCode, Seq("col_0", "col_2"), Some(1500L)) + val rightJoinParam = RddJoinParam(countryNames, Seq("col_0", "col_2"), Some(100L + Int.MaxValue)) + + // Test join wrapper function + val resultFrame = JoinRddFunctions.join(leftJoinParam, rightJoinParam, "left") val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( Column("col_0_L", DataTypes.int32, 0), Column("col_1_L", 
DataTypes.int32, 1), - Column("col_0", DataTypes.int32, 2), + Column("col_2_L", DataTypes.str, 2), Column("col_1_R", DataTypes.str, 3) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, "Oman")), - new GenericRow(Array[Any](null, null, 6, "Germany")) + new GenericRow(Array[Any](1, 354, "a", "Iceland")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land")), + new GenericRow(Array[Any](2, 91, "a", null)), + new GenericRow(Array[Any](2, 100, "b", "India")), + new GenericRow(Array[Any](3, 47, "a", "Norway")), + new GenericRow(Array[Any](4, 968, "c", "Oman")), + new GenericRow(Array[Any](5, 50, "c", null)) + ) + + results should contain theSameElementsAs expectedResults + } + + "joinRDDs" should "join two RDD with right join" in { + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val resultFrame = JoinRddFunctions.join( + RddJoinParam(countryCode, Seq("col_0")), + RddJoinParam(countryNames, Seq("col_0")), "right") + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_1_L", DataTypes.int32, 0), + Column("col_2_L", DataTypes.str, 1), + Column("col_0_R", DataTypes.int32, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) + )) + + val expectedResults = List( + new GenericRow(Array[Any](354, "a", 1, "Iceland", "a")), + new GenericRow(Array[Any](354, "a", 1, "Ice-land", "a")), + new GenericRow(Array[Any](91, "a", 2, "India", "b")), + new GenericRow(Array[Any](100, "b", 2, "India", "b")), + new GenericRow(Array[Any](47, "a", 3, "Norway", "a")), + new GenericRow(Array[Any](968, "c", 4, "Oman", "c")), + new GenericRow(Array[Any](null, null, 6, "Germany", "c")) ) results should contain theSameElementsAs expectedResults @@ -201,56 +323,145 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) val broadcastJoinThreshold = 1000 - val leftJoinParam = RddJoinParam(countryCode, "col_0", Some(800)) - val rightJoinParam = RddJoinParam(countryNames, "col_0", Some(4000)) + val leftJoinParam = RddJoinParam(countryCode, Seq("col_0"), Some(800)) + val rightJoinParam = RddJoinParam(countryNames, Seq("col_0"), Some(4000)) + + val resultFrame = JoinRddFunctions.join(leftJoinParam, rightJoinParam, "right", broadcastJoinThreshold) + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_1_L", DataTypes.int32, 0), + Column("col_2_L", DataTypes.str, 1), + Column("col_0_R", DataTypes.int32, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) + )) + + val expectedResults = List( + new GenericRow(Array[Any](354, "a", 1, "Iceland", "a")), + new GenericRow(Array[Any](354, "a", 1, "Ice-land", "a")), + new GenericRow(Array[Any](91, "a", 2, "India", "b")), + new GenericRow(Array[Any](100, "b", 2, "India", "b")), + new GenericRow(Array[Any](47, "a", 3, "Norway", "a")), + new GenericRow(Array[Any](968, "c", 4, "Oman", "c")), + new GenericRow(Array[Any](null, null, 6, "Germany", "c")) + ) + + results should contain theSameElementsAs expectedResults + } + + 
"compositeJoinRDDs" should "join two RDD with right join" in { + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val resultFrame = JoinRddFunctions.join( + RddJoinParam(countryCode, Seq("col_0", "col_2")), + RddJoinParam(countryNames, Seq("col_0", "col_2")), "right") + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_1_L", DataTypes.int32, 0), + Column("col_0_R", DataTypes.int32, 1), + Column("col_1_R", DataTypes.str, 2), + Column("col_2_R", DataTypes.str, 3) + )) + + val expectedResults = List( + new GenericRow(Array[Any](354, 1, "Iceland", "a")), + new GenericRow(Array[Any](354, 1, "Ice-land", "a")), + new GenericRow(Array[Any](100, 2, "India", "b")), + new GenericRow(Array[Any](47, 3, "Norway", "a")), + new GenericRow(Array[Any](968, 4, "Oman", "c")), + new GenericRow(Array[Any](null, 6, "Germany", "c")) + ) + + results should contain theSameElementsAs expectedResults + } + + "compositeJoinRDDs" should "join two RDD with right join using broadcast variable" in { + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val broadcastJoinThreshold = 1000 + val leftJoinParam = RddJoinParam(countryCode, Seq("col_0", "col_2"), Some(800)) + val rightJoinParam = RddJoinParam(countryNames, Seq("col_0", "col_2"), Some(4000)) val resultFrame = JoinRddFunctions.join(leftJoinParam, rightJoinParam, "right", broadcastJoinThreshold) val results = resultFrame.collect() + resultFrame.frameSchema.columns should equal(List( + Column("col_1_L", DataTypes.int32, 0), + Column("col_0_R", DataTypes.int32, 1), + Column("col_1_R", DataTypes.str, 2), + Column("col_2_R", DataTypes.str, 3) + )) + + val expectedResults = List( + new GenericRow(Array[Any](354, 1, "Iceland", "a")), + new GenericRow(Array[Any](354, 1, "Ice-land", "a")), + new GenericRow(Array[Any](100, 2, "India", "b")), + new GenericRow(Array[Any](47, 3, "Norway", "a")), + new GenericRow(Array[Any](968, 4, "Oman", "c")), + new GenericRow(Array[Any](null, 6, "Germany", "c")) + ) + + results should contain theSameElementsAs expectedResults + } + + "joinRDDs" should "join two RDD with outer join" in { + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0")), + RddJoinParam(countryNames, Seq("col_0")), "outer") + val results = resultFrame.collect() + resultFrame.frameSchema.columns should equal(List( Column("col_0_L", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0", DataTypes.int32, 2), - Column("col_1_R", DataTypes.str, 3) + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, "Oman")), - new GenericRow(Array[Any](null, null, 6, "Germany")) + new GenericRow(Array[Any](1, 354, "a", "Iceland", "a")), + new GenericRow(Array[Any](1, 354, "a", 
"Ice-land", "a")), + new GenericRow(Array[Any](2, 91, "a", "India", "b")), + new GenericRow(Array[Any](2, 100, "b", "India", "b")), + new GenericRow(Array[Any](3, 47, "a", "Norway", "a")), + new GenericRow(Array[Any](4, 968, "c", "Oman", "c")), + new GenericRow(Array[Any](5, 50, "c", null, null)), + new GenericRow(Array[Any](6, null, null, "Germany", "c")) ) results should contain theSameElementsAs expectedResults } - "joinRDDs" should "join two RDD with outer join" in { + "compositeJoinRDDs" should "join two RDD with outer join" in { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) - val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, "col_0"), - RddJoinParam(countryNames, "col_0"), "outer") + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0", "col_2")), + RddJoinParam(countryNames, Seq("col_0", "col_2")), "outer") val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( - Column("col_0", DataTypes.int32, 0), + Column("col_0_L", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), + Column("col_2_L", DataTypes.str, 2), Column("col_1_R", DataTypes.str, 3) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, "Oman")), - new GenericRow(Array[Any](5, 50, null, null)), - new GenericRow(Array[Any](6, null, 6, "Germany")) + new GenericRow(Array[Any](1, 354, "a", "Iceland")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land")), + new GenericRow(Array[Any](2, 91, "a", null)), + new GenericRow(Array[Any](2, 100, "b", "India")), + new GenericRow(Array[Any](3, 47, "a", "Norway")), + new GenericRow(Array[Any](4, 968, "c", "Oman")), + new GenericRow(Array[Any](5, 50, "c", null)), + new GenericRow(Array[Any](6, null, "c", "Germany")) ) results should contain theSameElementsAs expectedResults @@ -261,24 +472,51 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(emptyIdCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) - val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, "col_0"), - RddJoinParam(countryNames, "col_0"), "outer") + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0")), RddJoinParam(countryNames, Seq("col_0")), "outer") val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( - Column("col_0", DataTypes.int32, 0), + Column("col_0_L", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) + )) + + val expectedResults = List( + new GenericRow(Array[Any](1, null, null, "Iceland", "a")), + new GenericRow(Array[Any](1, null, null, "Ice-land", "a")), + new GenericRow(Array[Any](2, null, null, "India", "b")), + new GenericRow(Array[Any](3, null, null, "Norway", "a")), + new GenericRow(Array[Any](4, null, null, "Oman", "c")), + new GenericRow(Array[Any](6, null, null, "Germany", "c")) + ) + + results should contain 
theSameElementsAs expectedResults + } + + "compositeOuter join with empty left RDD" should "preserve the result from the right RDD" in { + val emptyIdCountryCodes = List.empty[Row] + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(emptyIdCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0", "col_2")), RddJoinParam(countryNames, Seq("col_0", "col_2")), "outer") + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_0_L", DataTypes.int32, 0), + Column("col_1_L", DataTypes.int32, 1), + Column("col_2_L", DataTypes.str, 2), Column("col_1_R", DataTypes.str, 3) )) val expectedResults = List( - new GenericRow(Array[Any](1, null, 1, "Iceland")), - new GenericRow(Array[Any](1, null, 1, "Ice-land")), - new GenericRow(Array[Any](2, null, 2, "India")), - new GenericRow(Array[Any](3, null, 3, "Norway")), - new GenericRow(Array[Any](4, null, 4, "Oman")), - new GenericRow(Array[Any](6, null, 6, "Germany")) + new GenericRow(Array[Any](1, null, "a", "Iceland")), + new GenericRow(Array[Any](1, null, "a", "Ice-land")), + new GenericRow(Array[Any](2, null, "b", "India")), + new GenericRow(Array[Any](3, null, "a", "Norway")), + new GenericRow(Array[Any](4, null, "c", "Oman")), + new GenericRow(Array[Any](6, null, "c", "Germany")) ) results should contain theSameElementsAs expectedResults @@ -289,24 +527,25 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(emptyIdCountryNames)) - val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, "col_0"), - RddJoinParam(countryNames, "col_0"), "outer") + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0")), + RddJoinParam(countryNames, Seq("col_0")), "outer") val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( - Column("col_0", DataTypes.int32, 0), + Column("col_0_L", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), - Column("col_1_R", DataTypes.str, 3) + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, null, null)), - new GenericRow(Array[Any](2, 91, null, null)), - new GenericRow(Array[Any](2, 100, null, null)), - new GenericRow(Array[Any](3, 47, null, null)), - new GenericRow(Array[Any](4, 968, null, null)), - new GenericRow(Array[Any](5, 50, null, null)) + new GenericRow(Array[Any](1, 354, "a", null, null)), + new GenericRow(Array[Any](2, 91, "a", null, null)), + new GenericRow(Array[Any](2, 100, "b", null, null)), + new GenericRow(Array[Any](3, 47, "a", null, null)), + new GenericRow(Array[Any](4, 968, "c", null, null)), + new GenericRow(Array[Any](5, 50, "c", null, null)) ) results should contain theSameElementsAs expectedResults @@ -327,12 +566,11 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val rddFiveHundredThousandsToOneFiftyThousands = new FrameRdd(inputSchema, sparkContext.parallelize(fiftyThousandToOneFiftyThousands)) - val resultFrame = JoinRddFunctions.join(RddJoinParam(rddOneToMillion, "col_0"), - RddJoinParam(rddFiveHundredThousandsToOneFiftyThousands, "col_0"), "outer") + val resultFrame = 
JoinRddFunctions.join(RddJoinParam(rddOneToMillion, Seq("col_0")), + RddJoinParam(rddFiveHundredThousandsToOneFiftyThousands, Seq("col_0")), "outer") resultFrame.frameSchema.columns should equal(List( - Column("col_0", DataTypes.int32, 0), - Column("col_0_R", DataTypes.int32, 1) + Column("col_0_L", DataTypes.int32, 0) )) resultFrame.count shouldBe 150000 } } diff --git a/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala b/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala index f119ec952f..4ca27d03bf 100644 --- a/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala +++ b/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala @@ -20,8 +20,7 @@ import java.io.File import java.util import org.trustedanalytics.atk.moduleloader.ClassLoaderAware -import org.trustedanalytics.atk.domain.frame.FrameReference -import org.trustedanalytics.atk.domain.frame.Udf +import org.trustedanalytics.atk.domain.frame.{ FrameReference, Udf } import org.trustedanalytics.atk.domain.schema.{ Column, DataTypes, Schema } import org.trustedanalytics.atk.engine.plugin.Invocation import org.trustedanalytics.atk.engine.{ SparkContextFactory, EngineConfig } @@ -82,8 +81,9 @@ object PythonRddStorage { /** * This method returns a FrameRdd after applying UDF on referencing FrameRdd + * * @param data Current referencing FrameRdd * @param aggregateByColumnKeys List of column name(s) based on which aggregation is performed * @param udf User Defined function(UDF) to apply on each row * @param udfSchema Mandatory output schema * @return FrameRdd @@ -91,11 +91,9 @@ object PythonRddStorage { def aggregateMapWith(data: FrameRdd, aggregateByColumnKeys: List[String], udf: Udf, udfSchema: Schema, sc: SparkContext): FrameRdd = { //Create a new schema which includes keys (KeyedSchema). val keyedSchema = udfSchema.copy(columns = data.frameSchema.columns(aggregateByColumnKeys) ++ udfSchema.columns) - //track key indices to fetch data during BSON decode. - val keyIndices = for (key <- aggregateByColumnKeys) yield data.frameSchema.columnIndex(key) val converter = DataTypes.parseMany(keyedSchema.columns.map(_.dataType).toArray)(_) val groupRDD = data.groupByRows(row => row.values(aggregateByColumnKeys)) - val pyRdd = aggregateRddToPyRdd(udf, groupRDD, keyIndices, sc) + val pyRdd = aggregateRddToPyRdd(udf, groupRDD, sc) val frameRdd = getRddFromPythonRdd(pyRdd, converter) FrameRdd.toFrameRdd(keyedSchema, frameRdd) } @@ -138,7 +136,8 @@ object PythonRddStorage { case y: scala.collection.mutable.Seq[_] => iterableToBsonList(y) case value => value }) BSON.encode(obj) } ) val pyRdd = getPyRdd(udf, sc, baseRdd, predicateInBytes) @@ -147,6 +146,7 @@ object PythonRddStorage { /** * This method converts the base RDD into Python RDD which is processed by the Python VM at the server. 
+ * + * @param udf UDF provided by the user * @param baseRdd Base RDD in Array[Bytes] * @param predicateInBytes UDF in Array[Bytes] @@ -195,12 +195,12 @@ object PythonRddStorage { /** * This method encodes the raw rdd into Bson to convert into PythonRDD + * * @param udf UDF provided by user to apply on each row * @param rdd rdd(List[keys], List[Rows]) - * @param keyIndices List of key indices, used to retreive key data with result frame * @return PythonRdd */ - def aggregateRddToPyRdd(udf: Udf, rdd: RDD[(List[Any], Iterable[Row])], keyIndices: List[Int], sc: SparkContext): EnginePythonRdd[Array[Byte]] = { + def aggregateRddToPyRdd(udf: Udf, rdd: RDD[(List[Any], Iterable[Row])], sc: SparkContext): EnginePythonRdd[Array[Byte]] = { val predicateInBytes = decodePythonBase64EncodedStrToBytes(udf.function) val baseRdd: RDD[Array[Byte]] = rdd.map { case (key, rows) => { @@ -214,9 +214,10 @@ object PythonRddStorage { case value => value } }).toArray - obj.put("keyindices", keyIndices.toArray) obj.put("array", bsonRows) BSON.encode(obj) } } val pyRdd = getPyRdd(udf, sc, baseRdd, predicateInBytes) @@ -249,13 +250,14 @@ object PythonRddStorage { //should be BasicBSONList containing only BasicBSONList objects val bson = BSON.decode(s) val asList = bson.get("array").asInstanceOf[BasicBSONList] asList.map(innerList => { val asBsonList = innerList.asInstanceOf[BasicBSONList] asBsonList.map { case x: BasicBSONList => x.toArray case value => value }.toArray.asInstanceOf[Array[Any]] }) }).map(converter) resultRdd diff --git a/python-client/trustedanalytics/core/frame.py b/python-client/trustedanalytics/core/frame.py index e62e62521f..c4b0bc4b8e 100644 --- a/python-client/trustedanalytics/core/frame.py +++ b/python-client/trustedanalytics/core/frame.py @@ -1071,8 +1071,8 @@ def __inspect(self, @api @beta @arg('right', 'Frame', "Another frame to join with") - @arg('left_on', str, "Name of the column in the left frame used to match up the two frames.") - @arg('right_on', str, "Name of the column in the right frame used to match up the two frames. " + @arg('left_on', list, "Names of the columns in the left frame used to match up the two frames.") + @arg('right_on', list, "Names of the columns in the right frame used to match up the two frames. " "Default is the same as the left frame.") @arg('how', str, "How to qualify the data to be joined together. Must be one of the following: " "'left', 'right', 'inner', 'outer'. Default is 'inner'") @@ -1085,21 +1085,21 @@ def __join(self, right, left_on, right_on=None, how='inner', name=None): Create a new frame from a SQL JOIN operation with another frame. The frame on the 'left' is the currently active frame. The frame on the 'right' is another frame. - This method takes a column in the left frame and matches its values - with a column in the right frame. + This method takes column(s) in the left frame and matches their values + with column(s) in the right frame. Using the default 'how' option ['inner'] will only allow data in the resultant frame if both the left and right frames have the same value - in the matching column. + in the matching column(s). Using the 'left' 'how' option will allow any data in the resultant frame if it exists in the left frame, but will allow any data from the - right frame if it has a value in its column which matches the value in - the left frame column. 
+        right frame if it has a value in its column(s) which matches the value in
+        the left frame column(s).
         Using the 'right' option works similarly, except it keeps all the data
         from the right frame and only the data from the left frame when it
         matches.
         The 'outer' option provides a frame with data from both frames where
         the left and right frames did not have the same value in the matching
-        column.
+        column(s).

         Notes
         -----
@@ -1130,6 +1130,20 @@ def __join(self, right, left_on, right_on=None, how='inner', name=None):
         >>> colors = ta.Frame(ta.UploadRows([[1, 'red'], [2, 'yellow'], [3, 'green'], [4, 'blue']], [('numbers', ta.int32), ('color', str)]))
         -etc-

+        >>> country_code_rows = [[1, 354, "a"], [2, 91, "a"], [2, 100, "b"], [3, 47, "a"], [4, 968, "c"], [5, 50, "c"]]
+        >>> country_code_schema = [("col_0", int), ("col_1", int), ("col_2", str)]
+        -etc-
+
+        >>> country_name_rows = [[1, "Iceland", "a"], [1, "Ice-land", "a"], [2, "India", "b"], [3, "Norway", "a"], [4, "Oman", "c"], [6, "Germany", "c"]]
+        >>> country_names_schema = [("col_0", int), ("col_1", str), ("col_2", str)]
+        -etc-
+
+        >>> country_codes_frame = ta.Frame(ta.UploadRows(country_code_rows, country_code_schema))
+        -etc-
+
+        >>> country_names_frame = ta.Frame(ta.UploadRows(country_name_rows, country_names_schema))
+        -etc-
+
         Consider two frames: codes and colors

@@ -1161,12 +1175,7 @@
         >>> j = codes.join(colors, 'numbers')

-        >>> j.inspect()
-
-
         >>> j.inspect(columns=['numbers', 'color'])
-
         [#] numbers  color
         ====================
         [0]       1  red
@@ -1183,45 +1192,73 @@
         >>> j_left = codes.join(colors, 'numbers', how='left')

         >>> j_left.inspect()
-
-        >>> j_left.inspect(columns=['numbers', 'color'])
-
-        [#] numbers  color
-        ====================
-        [0]       1  red
-        [1]       3  green
-        [2]       1  red
-        [3]       0  None
-        [4]       2  yellow
-        [5]       1  red
-        [6]       5  None
-        [7]       3  green
+        [#] numbers_L  color
+        ======================
+        [0]         1  red
+        [1]         3  green
+        [2]         1  red
+        [3]         0  None
+        [4]         2  yellow
+        [5]         1  red
+        [6]         5  None
+        [7]         3  green
+
         And an outer join:

         >>> j_outer = codes.join(colors, 'numbers', how='outer')

         >>> j_outer.inspect()
-
-        >>> j_outer.inspect(columns=['numbers', 'color'])
-
-        [#] numbers  color
-        ====================
-        [0]       0  None
-        [1]       1  red
-        [2]       1  red
-        [3]       1  red
-        [4]       2  yellow
-        [5]       3  green
-        [6]       3  green
-        [7]       4  blue
-        [8]       5  None
+        [#] numbers_L  color
+        ======================
+        [0]         0  None
+        [1]         1  red
+        [2]         1  red
+        [3]         1  red
+        [4]         2  yellow
+        [5]         3  green
+        [6]         3  green
+        [7]         4  blue
+        [8]         5  None
+
+        Consider two frames: country_codes_frame and country_names_frame
+
+        >>> country_codes_frame.inspect()
+        [#] col_0  col_1  col_2
+        ========================
+        [0]     1    354  a
+        [1]     2     91  a
+        [2]     2    100  b
+        [3]     3     47  a
+        [4]     4    968  c
+        [5]     5     50  c
+
+
+        >>> country_names_frame.inspect()
+        [#] col_0  col_1     col_2
+        ===========================
+        [0]     1  Iceland   a
+        [1]     1  Ice-land  a
+        [2]     2  India     b
+        [3]     3  Norway    a
+        [4]     4  Oman      c
+        [5]     6  Germany   c
+
+        Join them on the 'col_0' and 'col_2' columns ('inner' join by default)
+
+        >>> composite_join = country_codes_frame.join(country_names_frame, ['col_0', 'col_2'])
+
+
+        >>> composite_join.inspect()
+        [#] col_0  col_1_L  col_2  col_1_R
+        ====================================
+        [0]     1      354  a      Ice-land
+        [1]     1      354  a      Iceland
+        [2]     2      100  b      India
+        [3]     3       47  a      Norway
+        [4]     4      968  c      Oman

         More examples can be found in the :ref:`user manual `.
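[Reviewer note] The docstring above now accepts either a single column name or a list for `left_on`/`right_on`; the promotion from string to list happens client-side in `rest/frame.py` (next file). As a minimal standalone sketch of that normalization, with a hypothetical helper name and Python 3's `str` standing in for the client's Python 2 `basestring` check:

```python
def normalize_join_columns(left_on, right_on=None):
    """Promote single column names to lists and validate the pair."""
    if left_on is None:
        raise ValueError("Please provide the column name(s) on which the join should be performed")
    if isinstance(left_on, str):
        left_on = [left_on]
    if right_on is None:
        right_on = list(left_on)  # default: join on the same columns
    elif isinstance(right_on, str):
        right_on = [right_on]
    if len(left_on) != len(right_on):
        raise ValueError("Please provide an equal number of left and right join columns")
    return left_on, right_on

# Both spellings produce the same request payload:
assert normalize_join_columns('numbers') == (['numbers'], ['numbers'])
assert normalize_join_columns(['col_0', 'col_2']) == (['col_0', 'col_2'], ['col_0', 'col_2'])
```

Doing the promotion in the client keeps the server contract uniform: `join_columns` is always a list, even for single-column joins.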
diff --git a/python-client/trustedanalytics/rest/frame.py b/python-client/trustedanalytics/rest/frame.py
index eeaf4f3264..19220281c7 100644
--- a/python-client/trustedanalytics/rest/frame.py
+++ b/python-client/trustedanalytics/rest/frame.py
@@ -329,12 +329,14 @@ def aggregate_with_udf(self, frame, group_by_column_keys, aggregator_expression,
         aggregate_with_udf_function = get_group_by_aggregator_function(aggregator_expression, data_types)

+        key_indices = FrameSchema.get_indices_for_selected_columns(frame.schema, group_by_column_keys)
+
         from itertools import imap
         arguments = {"frame": frame.uri,
                      "aggregate_by_column_keys": group_by_column_keys,
                      "column_names": names,
                      "column_types": [get_rest_str_from_data_type(t) for t in data_types],
-                     "udf": get_aggregator_udf_arg(frame, aggregate_with_udf_function, imap, output_schema, init_acc_values)}
+                     "udf": get_aggregator_udf_arg(frame, aggregate_with_udf_function, imap, key_indices, output_schema, init_acc_values)}
         return execute_new_frame_command('frame/aggregate_with_udf', arguments)

@@ -468,12 +470,20 @@ def inspect(self, frame, n, offset, selected_columns, format_settings):
         return RowsInspection(data, schema, offset=offset, format_settings=format_settings)

     def join(self, left, right, left_on, right_on, how, name=None):
+        if left_on is None:
+            raise ValueError("Please provide the column name(s) on which the join should be performed")
+        elif isinstance(left_on, basestring):
+            left_on = [left_on]
         if right_on is None:
             right_on = left_on
+        elif isinstance(right_on, basestring):
+            right_on = [right_on]
+        if len(left_on) != len(right_on):
+            raise ValueError("Please provide an equal number of left and right join columns")
         arguments = {"name": name,
                      "how": how,
-                     "left_frame": {"frame": left.uri, "join_column": left_on},
-                     "right_frame": {"frame": right.uri, "join_column": right_on}}
+                     "left_frame": {"frame": left.uri, "join_columns": left_on},
+                     "right_frame": {"frame": right.uri, "join_columns": right_on}}
         return execute_new_frame_command('frame:/join', arguments)

     def copy(self, frame, columns=None, where=None, name=None):
@@ -513,7 +523,18 @@ def group_by(self, frame, group_by_columns, aggregation):
                 if arg == agg.count:
                     aggregation_list.append({'function': agg.count, 'column_name': first_column_name, 'new_column_name': "count"})
                 else:
+                    # validate the arguments: initial values, when given, must cover every column in the output schema
+                    if arg.init_values is not None and len(arg.output_schema) != len(arg.init_values):
+                        raise ValueError("Provide initial values for all column names in output schema or leave initial values as empty")
                     return FrameBackendRest.aggregate_with_udf(self, frame, group_by_columns, arg.aggregator, arg.output_schema, arg.init_values)
             elif isinstance(arg, dict):
                 for k,v in arg.iteritems():
                     # leave the valid column check to the server
diff --git a/python-client/trustedanalytics/rest/spark.py b/python-client/trustedanalytics/rest/spark.py
index 1f0982181c..07712ce2b7 100644
--- a/python-client/trustedanalytics/rest/spark.py
+++ b/python-client/trustedanalytics/rest/spark.py
@@ -128,7 +128,7 @@ class RowWrapper(Row):
     def load_row(self, data):
         self._set_data(data)

-def _wrap_aggregator_rows_function(frame, aggregator_function, aggregator_schema, init_acc_values, optional_schema=None):
+def _wrap_aggregator_rows_function(frame, aggregator_function, key_indices, aggregator_schema, init_acc_values,
+                                   optional_schema=None):
     """
     Wraps a python row function, like one used for a filter predicate, such
     that it will be evaluated with using the expected 'row' object rather than
@@ -144,20 +144,20 @@ def _wrap_aggregator_rows_function(frame, aggregator_function, aggregator_schema
     acc_wrapper = MutableRow(aggregator_schema)
     row_wrapper = RowWrapper(row_schema)

     def rows_func(rows):
         try:
             bson_data = bson.decode_all(rows)[0]
             rows_data = bson_data['array']
-            key_indices = bson_data['keyindices']
             acc_wrapper._set_data(list(init_acc_values))
             for row in rows_data:
                 row_wrapper.load_row(row)
                 aggregator_function(acc_wrapper, row_wrapper)
             result = []
             for key_index in key_indices:
-                answer = [rows_data[0][key_index]]
-                result.extend(answer)
+                result.append(rows_data[0][key_index])
             result.extend(acc_wrapper._get_data())
             return numpy_to_bson_friendly(result)
         except Exception as e:
diff --git a/python-client/trustedanalytics/rest/spark_helper.py b/python-client/trustedanalytics/rest/spark_helper.py
index a34b48ab37..fba4d20c99 100644
--- a/python-client/trustedanalytics/rest/spark_helper.py
+++ b/python-client/trustedanalytics/rest/spark_helper.py
@@ -81,7 +81,7 @@ def iteration_ready_function(s, iterator): return iterator_function(iterator)

-def get_aggregator_udf_arg(frame, aggregator_function, iteration_function, aggregator_schema, init_val, optional_schema=None):
+def get_aggregator_udf_arg(frame, aggregator_function, iteration_function, key_indices, aggregator_schema, init_val, optional_schema=None):
     """
     Prepares a python row function for server execution and http transmission
@@ -95,7 +95,7 @@ def get_aggregator_udf_arg(frame, aggregator_function, iteration_function, aggre
        the iteration function to apply for the frame. In general, it is imap. For filter however, it is ifilter
     """
-    row_ready_function = _wrap_aggregator_rows_function(frame, aggregator_function, aggregator_schema, init_val, optional_schema)
+    row_ready_function = _wrap_aggregator_rows_function(frame, aggregator_function, key_indices, aggregator_schema, init_val, optional_schema)
     def iterator_function(iterator): return iteration_function(row_ready_function, iterator)
     def iteration_ready_function(s, iterator): return iterator_function(iterator)
     x = make_http_ready(iteration_ready_function)
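[Reviewer note] Taken together, these changes move the group-by key bookkeeping client-side: `key_indices` is computed from the schema in `rest/frame.py` and captured in the UDF closure, rather than shipped inside every BSON document under `keyindices`. A toy sketch of the per-group contract, using hypothetical data and a hypothetical aggregator (not ATK code):

```python
from collections import defaultdict

# Toy rows standing in for one frame; column positions are illustrative.
rows = [[1, 354, "a"], [2, 91, "a"], [2, 100, "b"], [3, 47, "a"]]
key_indices = [0, 2]  # indices of the group-by columns, computed client-side

# The server groups rows by the key-column values (groupByRows) before
# handing each group to the Python VM as one BSON document.
groups = defaultdict(list)
for row in rows:
    groups[tuple(row[i] for i in key_indices)].append(row)

def aggregator(acc, row):
    acc[0] += row[1]  # e.g. a running sum over the second column

# Per group: fold the rows into the accumulator, then emit the key values
# followed by the accumulated values.
for group_rows in groups.values():
    acc = [0]  # init_acc_values
    for row in group_rows:
        aggregator(acc, row)
    result = [group_rows[0][i] for i in key_indices] + acc
    print(result)  # e.g. [2, 'a', 91] for key (2, 'a')
```

Each emitted row is the group's key values followed by the accumulator, matching the KeyedSchema (key columns ++ UDF output columns) that `PythonRddStorage.aggregateMapWith` builds on the server.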