From 702752b3f9a6c14d8f7523e6aa3a7a0183b6971a Mon Sep 17 00:00:00 2001
From: Karthik Vadla
Date: Thu, 17 Mar 2016 14:06:13 -0700
Subject: [PATCH 01/15] Added validation for init_values in the group_by
 aggregator

---
 python-client/trustedanalytics/rest/frame.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/python-client/trustedanalytics/rest/frame.py b/python-client/trustedanalytics/rest/frame.py
index eeaf4f3264..22f39d1ab4 100644
--- a/python-client/trustedanalytics/rest/frame.py
+++ b/python-client/trustedanalytics/rest/frame.py
@@ -513,7 +513,17 @@ def group_by(self, frame, group_by_columns, aggregation):
             if arg == agg.count:
                 aggregation_list.append({'function': agg.count, 'column_name': first_column_name, 'new_column_name': "count"})
             else:
-                return FrameBackendRest.aggregate_with_udf(self, frame, group_by_columns, arg.aggregator, arg.output_schema, arg.init_values)
+                init_flag = False
+                if arg.init_values is None:
+                    init_flag = True
+                else:
+                    if len(arg.output_schema) == len(arg.init_values):
+                        init_flag = True
+
+                if init_flag:
+                    return FrameBackendRest.aggregate_with_udf(self, frame, group_by_columns, arg.aggregator, arg.output_schema, arg.init_values)
+                else:
+                    raise ValueError("Provide an initial value for every column in the output schema, or leave init_values empty")
         elif isinstance(arg, dict):
             for k,v in arg.iteritems():
                 # leave the valid column check to the server

From d56cd4bed234a40494a10b71bde0761a59ddbc76 Mon Sep 17 00:00:00 2001
From: Karthik Vadla
Date: Thu, 17 Mar 2016 16:13:18 -0700
Subject: [PATCH 02/15] added comments to rest/frame.py

---
 python-client/trustedanalytics/rest/frame.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python-client/trustedanalytics/rest/frame.py b/python-client/trustedanalytics/rest/frame.py
index 22f39d1ab4..ed89e335ed 100644
--- a/python-client/trustedanalytics/rest/frame.py
+++ b/python-client/trustedanalytics/rest/frame.py
@@ -513,6 +513,7 @@ def group_by(self, frame, group_by_columns, aggregation):
             if arg == agg.count:
                 aggregation_list.append({'function': agg.count, 'column_name': first_column_name, 'new_column_name': "count"})
             else:
+                # validate the aggregator arguments
                 init_flag = False
                 if arg.init_values is None:
                     init_flag = True

From 67d1dadd0b234c02cc5a986eadec4d4d45d3aa24 Mon Sep 17 00:00:00 2001
From: Karthik Vadla
Date: Thu, 24 Mar 2016 12:25:00 -0700
Subject: [PATCH 03/15] Modifications to group_by: removed the unnecessary
 'keyindices' BSON transfer

---
 .../atk/engine/frame/PythonRddStorage.scala          |  9 ++++-----
 python-client/trustedanalytics/rest/frame.py         |  4 +++-
 python-client/trustedanalytics/rest/spark.py         | 11 ++++++-----
 python-client/trustedanalytics/rest/spark_helper.py  |  4 ++--
 4 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala b/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala
index f119ec952f..9a1dc08f70 100644
--- a/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala
+++ b/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala
@@ -92,10 +92,10 @@ object PythonRddStorage {
     //Create a new schema which includes keys (KeyedSchema).
     val keyedSchema = udfSchema.copy(columns = data.frameSchema.columns(aggregateByColumnKeys) ++ udfSchema.columns)
     //track key indices to fetch data during BSON decode.
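     //(patch 03 moves this work to the Python client: rest/frame.py derives the key
     //indices from the frame schema and closes over them in the UDF, so they no
     //longer need to be encoded into every BSON document)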
- val keyIndices = for (key <- aggregateByColumnKeys) yield data.frameSchema.columnIndex(key) + //val keyIndices = for (key <- aggregateByColumnKeys) yield data.frameSchema.columnIndex(key) val converter = DataTypes.parseMany(keyedSchema.columns.map(_.dataType).toArray)(_) val groupRDD = data.groupByRows(row => row.values(aggregateByColumnKeys)) - val pyRdd = aggregateRddToPyRdd(udf, groupRDD, keyIndices, sc) + val pyRdd = aggregateRddToPyRdd(udf, groupRDD, sc) val frameRdd = getRddFromPythonRdd(pyRdd, converter) FrameRdd.toFrameRdd(keyedSchema, frameRdd) } @@ -197,10 +197,9 @@ object PythonRddStorage { * This method encodes the raw rdd into Bson to convert into PythonRDD * @param udf UDF provided by user to apply on each row * @param rdd rdd(List[keys], List[Rows]) - * @param keyIndices List of key indices, used to retreive key data with result frame * @return PythonRdd */ - def aggregateRddToPyRdd(udf: Udf, rdd: RDD[(List[Any], Iterable[Row])], keyIndices: List[Int], sc: SparkContext): EnginePythonRdd[Array[Byte]] = { + def aggregateRddToPyRdd(udf: Udf, rdd: RDD[(List[Any], Iterable[Row])], sc: SparkContext): EnginePythonRdd[Array[Byte]] = { val predicateInBytes = decodePythonBase64EncodedStrToBytes(udf.function) val baseRdd: RDD[Array[Byte]] = rdd.map { case (key, rows) => { @@ -214,7 +213,7 @@ object PythonRddStorage { case value => value } }).toArray - obj.put("keyindices", keyIndices.toArray) + //obj.put("keyindices", keyIndices.toArray) obj.put("array", bsonRows) BSON.encode(obj) } diff --git a/python-client/trustedanalytics/rest/frame.py b/python-client/trustedanalytics/rest/frame.py index ed89e335ed..6f3f25d7e8 100644 --- a/python-client/trustedanalytics/rest/frame.py +++ b/python-client/trustedanalytics/rest/frame.py @@ -329,12 +329,14 @@ def aggregate_with_udf(self, frame, group_by_column_keys, aggregator_expression, aggregate_with_udf_function = get_group_by_aggregator_function(aggregator_expression, data_types) + key_indices = FrameSchema.get_indices_for_selected_columns(frame.schema, group_by_column_keys) + from itertools import imap arguments = { "frame": frame.uri, "aggregate_by_column_keys": group_by_column_keys, "column_names": names, "column_types": [get_rest_str_from_data_type(t) for t in data_types], - "udf": get_aggregator_udf_arg(frame, aggregate_with_udf_function, imap, output_schema, init_acc_values) + "udf": get_aggregator_udf_arg(frame, aggregate_with_udf_function, imap, key_indices, output_schema, init_acc_values) } return execute_new_frame_command('frame/aggregate_with_udf', arguments) diff --git a/python-client/trustedanalytics/rest/spark.py b/python-client/trustedanalytics/rest/spark.py index 1f0982181c..c3ba86ba17 100644 --- a/python-client/trustedanalytics/rest/spark.py +++ b/python-client/trustedanalytics/rest/spark.py @@ -128,7 +128,7 @@ class RowWrapper(Row): def load_row(self, data): self._set_data(data) -def _wrap_aggregator_rows_function(frame, aggregator_function, aggregator_schema, init_acc_values, optional_schema=None): +def _wrap_aggregator_rows_function(frame, aggregator_function, key_indices, aggregator_schema, init_acc_values, optional_schema=None): """ Wraps a python row function, like one used for a filter predicate, such that it will be evaluated with using the expected 'row' object rather than @@ -144,20 +144,21 @@ def _wrap_aggregator_rows_function(frame, aggregator_function, aggregator_schema acc_wrapper = MutableRow(aggregator_schema) row_wrapper = RowWrapper(row_schema) + key_indices_wrapper = key_indices def rows_func(rows): try: 
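             # each `rows` value is one BSON document per group-by key, shaped like
             # {'array': [row, row, ...]}; the key positions come from the closed-over
             # key_indices computed client-side, not a per-document 'keyindices' field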
                bson_data = bson.decode_all(rows)[0]
                rows_data = bson_data['array']
-               key_indices = bson_data['keyindices']
+               #key_indices = bson_data['keyindices']
                acc_wrapper._set_data(list(init_acc_values))
                for row in rows_data:
                    row_wrapper.load_row(row)
                    aggregator_function(acc_wrapper, row_wrapper)
                result = []
-               for key_index in key_indices:
-                   answer = [rows_data[0][key_index]]
-                   result.extend(answer)
+               for key_index in key_indices_wrapper:
+                   answer = rows_data[0][key_index]
+                   result.append(answer)
                result.extend(acc_wrapper._get_data())
                return numpy_to_bson_friendly(result)
            except Exception as e:
diff --git a/python-client/trustedanalytics/rest/spark_helper.py b/python-client/trustedanalytics/rest/spark_helper.py
index a34b48ab37..fba4d20c99 100644
--- a/python-client/trustedanalytics/rest/spark_helper.py
+++ b/python-client/trustedanalytics/rest/spark_helper.py
@@ -81,7 +81,7 @@ def iteration_ready_function(s, iterator): return iterator_function(iterator)
 
-def get_aggregator_udf_arg(frame, aggregator_function, iteration_function, aggregator_schema, init_val, optional_schema=None):
+def get_aggregator_udf_arg(frame, aggregator_function, iteration_function, key_indices, aggregator_schema, init_val, optional_schema=None):
     """
     Prepares a python row function for server execution and http transmission
 
@@ -95,7 +95,7 @@ def get_aggregator_udf_arg(frame, aggregator_function, iteration_function, aggre
         the iteration function to apply for the frame.  In general, it is imap.  For filter however, it is ifilter
     """
-    row_ready_function = _wrap_aggregator_rows_function(frame, aggregator_function, aggregator_schema, init_val, optional_schema)
+    row_ready_function = _wrap_aggregator_rows_function(frame, aggregator_function, key_indices, aggregator_schema, init_val, optional_schema)
     def iterator_function(iterator): return iteration_function(row_ready_function, iterator)
     def iteration_ready_function(s, iterator): return iterator_function(iterator)
     x = make_http_ready(iteration_ready_function)

From aee3226aba01e417a09fa7bd71c936fceb01870b Mon Sep 17 00:00:00 2001
From: Karthik Vadla
Date: Thu, 31 Mar 2016 17:21:40 -0700
Subject: [PATCH 04/15] added timers to serialization and deserialization

---
 .../plugins/UnflattenColumnsPlugin.scala      |  7 +++
 .../atk/engine/frame/PythonRddStorage.scala   | 59 +++++++++++++++----
 python-client/trustedanalytics/rest/spark.py |  1 -
 3 files changed, 53 insertions(+), 14 deletions(-)

diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/UnflattenColumnsPlugin.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/UnflattenColumnsPlugin.scala
index 0cff524f7f..23defad426 100644
--- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/UnflattenColumnsPlugin.scala
+++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/UnflattenColumnsPlugin.scala
@@ -68,7 +68,14 @@ class UnflattenColumnPlugin extends SparkCommandPlugin[UnflattenColumnArgs, Unit
     // run the operation
     val targetSchema = UnflattenColumnFunctions.createTargetSchema(schema, compositeKeyNames)
+
+    //timer around the unflatten group-by (count() forces evaluation before `end` is taken)
+    println(s"Row count before unflatten group-by: ${frame.rdd.count()}")
+    val start = System.nanoTime()
     val initialRdd = frame.rdd.groupByRows(row => row.values(compositeKeyNames))
+    initialRdd.count()
+    val end = System.nanoTime()
+    println(s"Unflatten group-by time: ${end - start} ns")
     val resultRdd =
UnflattenColumnFunctions.unflattenRddByCompositeKey(compositeKeyIndices, initialRdd, targetSchema, arguments.delimiter.getOrElse(defaultDelimiter)) frame.save(new FrameRdd(targetSchema, resultRdd)) diff --git a/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala b/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala index 9a1dc08f70..8fe1245fea 100644 --- a/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala +++ b/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala @@ -20,8 +20,7 @@ import java.io.File import java.util import org.trustedanalytics.atk.moduleloader.ClassLoaderAware -import org.trustedanalytics.atk.domain.frame.FrameReference -import org.trustedanalytics.atk.domain.frame.Udf +import org.trustedanalytics.atk.domain.frame.{ FrameReference, Udf } import org.trustedanalytics.atk.domain.schema.{ Column, DataTypes, Schema } import org.trustedanalytics.atk.engine.plugin.Invocation import org.trustedanalytics.atk.engine.{ SparkContextFactory, EngineConfig } @@ -75,15 +74,21 @@ object PythonRddStorage { udfSchema } val converter = DataTypes.parseMany(newSchema.columns.map(_.dataType).toArray)(_) - val pyRdd = rddToPyRdd(udf, data, sc) - val frameRdd = getRddFromPythonRdd(pyRdd, converter) + val accumulatorSer = sc.accumulator(0L, "mytimerSerGeneric") + val accumulatorDeSer = sc.accumulator(0L, "mytimerDeSerGeneric") + val pyRdd = rddToPyRdd(udf, data, sc, accumulatorSer) + val frameRdd = getRddFromPythonRdd(pyRdd, converter, accumulatorDeSer) + println(s"MytimerSer in mapWith took ${accumulatorSer.value}") + frameRdd.count() + println(s"MytimerDeSer in mapWith took ${accumulatorDeSer.value}") FrameRdd.toFrameRdd(newSchema, frameRdd) } /** * This method returns a FrameRdd after applying UDF on referencing FrameRdd + * * @param data Current referencing FrameRdd - * @param aggregateByColumnKeys List of column name(s) based on which aggregation is performed + * @param aggregateByColumnKeys List of column name(s) based on which yeahaggregation is performed * @param udf User Defined function(UDF) to apply on each row * @param udfSchema Mandatory output schema * @return FrameRdd @@ -94,9 +99,17 @@ object PythonRddStorage { //track key indices to fetch data during BSON decode. 
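     //(the accumulators added below collect per-task encode/decode nanoseconds;
     //their totals are only meaningful on the driver after an action such as
     //count() has forced the lazy RDDs to evaluate)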
//val keyIndices = for (key <- aggregateByColumnKeys) yield data.frameSchema.columnIndex(key) val converter = DataTypes.parseMany(keyedSchema.columns.map(_.dataType).toArray)(_) + val groupRDD = data.groupByRows(row => row.values(aggregateByColumnKeys)) - val pyRdd = aggregateRddToPyRdd(udf, groupRDD, sc) - val frameRdd = getRddFromPythonRdd(pyRdd, converter) + + val accumulatorSer = sc.accumulator(0L, "mytimerSerAggregated") + val accumulatorDeSer = sc.accumulator(0L, "mytimerDeSerAggregated") + val pyRdd = aggregateRddToPyRdd(udf, groupRDD, sc, accumulatorSer) + val frameRdd = getRddFromPythonRdd(pyRdd, converter, accumulatorDeSer) + //serialization timer + println(s"MytimerSer in AggregateUDF took ${accumulatorSer.value}") + frameRdd.count() + println(s"MytimerDeSer in AggregateUDF took ${accumulatorDeSer.value}") FrameRdd.toFrameRdd(keyedSchema, frameRdd) } @@ -126,11 +139,12 @@ object PythonRddStorage { bsonList } - def rddToPyRdd(udf: Udf, rdd: RDD[Row], sc: SparkContext): EnginePythonRdd[Array[Byte]] = { + def rddToPyRdd(udf: Udf, rdd: RDD[Row], sc: SparkContext, acc: Accumulator[Long] = null): EnginePythonRdd[Array[Byte]] = { val predicateInBytes = decodePythonBase64EncodedStrToBytes(udf.function) // Create an RDD of byte arrays representing bson objects val baseRdd: RDD[Array[Byte]] = rdd.map( x => { + val start = System.nanoTime() val obj = new BasicBSONObject() obj.put("array", x.toSeq.toArray.map { case y: ArrayBuffer[_] => iterableToBsonList(y) @@ -138,15 +152,21 @@ object PythonRddStorage { case y: scala.collection.mutable.Seq[_] => iterableToBsonList(y) case value => value }) - BSON.encode(obj) + val res = BSON.encode(obj) + println(s"Bson Encoded obj: ${res}") + if (acc != null) + acc += (System.nanoTime() - start) + res } ) + println(s"RddToPyRddGeneric Bytes ${baseRdd.first()} ${baseRdd.first().length}") val pyRdd = getPyRdd(udf, sc, baseRdd, predicateInBytes) pyRdd } /** * This method converts the base RDD into Python RDD which is processed by the Python VM at the server. 
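 * Each element of the base RDD is a BSON-encoded byte array; the Python worker
 * decodes it, applies the user's UDF, and streams BSON-encoded result rows back.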
+ * * @param udf UDF provided by the user * @param baseRdd Base RDD in Array[Bytes] * @param predicateInBytes UDF in Array[Bytes] @@ -195,14 +215,16 @@ object PythonRddStorage { /** * This method encodes the raw rdd into Bson to convert into PythonRDD + * * @param udf UDF provided by user to apply on each row * @param rdd rdd(List[keys], List[Rows]) * @return PythonRdd */ - def aggregateRddToPyRdd(udf: Udf, rdd: RDD[(List[Any], Iterable[Row])], sc: SparkContext): EnginePythonRdd[Array[Byte]] = { + def aggregateRddToPyRdd(udf: Udf, rdd: RDD[(List[Any], Iterable[Row])], sc: SparkContext, acc: Accumulator[Long] = null): EnginePythonRdd[Array[Byte]] = { val predicateInBytes = decodePythonBase64EncodedStrToBytes(udf.function) val baseRdd: RDD[Array[Byte]] = rdd.map { case (key, rows) => { + val x = System.nanoTime() val obj = new BasicBSONObject() val bsonRows = rows.map( row => { @@ -215,9 +237,15 @@ object PythonRddStorage { }).toArray //obj.put("keyindices", keyIndices.toArray) obj.put("array", bsonRows) - BSON.encode(obj) + val res = BSON.encode(obj) + println(s"Bson Encoded obj: ${res}") + val y = System.nanoTime() + if (acc != null) + acc += (y - x) + res } } + println(s"agg12-RddToPyRddAggregation Bytes ${baseRdd.first()} ${baseRdd.first().length}") val pyRdd = getPyRdd(udf, sc, baseRdd, predicateInBytes) pyRdd } @@ -243,18 +271,23 @@ object PythonRddStorage { result.headOption } - def getRddFromPythonRdd(pyRdd: EnginePythonRdd[Array[Byte]], converter: (Array[Any] => Array[Any]) = null): RDD[Array[Any]] = { + def getRddFromPythonRdd(pyRdd: EnginePythonRdd[Array[Byte]], converter: (Array[Any] => Array[Any]) = null, acc: Accumulator[Long] = null): RDD[Array[Any]] = { val resultRdd = pyRdd.flatMap(s => { + val start = System.nanoTime() //should be BasicBSONList containing only BasicBSONList objects val bson = BSON.decode(s) val asList = bson.get("array").asInstanceOf[BasicBSONList] - asList.map(innerList => { + val res = asList.map(innerList => { val asBsonList = innerList.asInstanceOf[BasicBSONList] asBsonList.map { case x: BasicBSONList => x.toArray case value => value }.toArray.asInstanceOf[Array[Any]] }) + val end = System.nanoTime() + if (acc != null) + acc += (end - start) + res }).map(converter) resultRdd diff --git a/python-client/trustedanalytics/rest/spark.py b/python-client/trustedanalytics/rest/spark.py index c3ba86ba17..07712ce2b7 100644 --- a/python-client/trustedanalytics/rest/spark.py +++ b/python-client/trustedanalytics/rest/spark.py @@ -145,7 +145,6 @@ def _wrap_aggregator_rows_function(frame, aggregator_function, key_indices, aggr acc_wrapper = MutableRow(aggregator_schema) row_wrapper = RowWrapper(row_schema) key_indices_wrapper = key_indices - def rows_func(rows): try: bson_data = bson.decode_all(rows)[0] From ec816c1f24184ef2bf1b97734a0fa228a1ce8378 Mon Sep 17 00:00:00 2001 From: Karthik Vadla Date: Mon, 4 Apr 2016 09:37:25 -0700 Subject: [PATCH 05/15] added changes to join plugin to support composite key --- .../engine/frame/plugins/join/JoinArgs.scala | 10 ++--- .../frame/plugins/join/JoinPlugin.scala | 22 +++++----- .../frame/plugins/join/JoinRddFunctions.scala | 43 +++++++++++++++++-- .../frame/plugins/join/RddJoinParam.scala | 6 +-- python-client/trustedanalytics/core/frame.py | 4 +- python-client/trustedanalytics/rest/frame.py | 4 +- 6 files changed, 63 insertions(+), 26 deletions(-) diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinArgs.scala 
b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinArgs.scala index b903fd7727..611772216c 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinArgs.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinArgs.scala @@ -24,15 +24,15 @@ import org.trustedanalytics.atk.engine.plugin.{ ArgDoc, Invocation } * Arguments for Join plugin * */ -case class JoinArgs(leftFrame: JoinFrameArgs, +case class JoinArgs(leftFrame: JoinFrameArgs, @ArgDoc("""Join arguments for first data frame.""") rightFrame: JoinFrameArgs, @ArgDoc("""Methods of join (inner, left, right or outer).""") how: String, @ArgDoc("""Name of new frame to be created.""") name: Option[String] = None, @ArgDoc("""The type of skewed join: 'skewedhash' or 'skewedbroadcast'""") skewedJoinType: Option[String] = None) { require(leftFrame != null && leftFrame.frame != null, "left frame is required") require(rightFrame != null && rightFrame.frame != null, "right frame is required") - require(leftFrame.joinColumn != null, "left join column is required") - require(rightFrame.joinColumn != null, "right join column is required") + require(leftFrame.joinColumns != null, "left join column is required") + require(rightFrame.joinColumns != null, "right join column is required") require(how != null, "join method is required") require(skewedJoinType.isEmpty || (skewedJoinType.get == "skewedhash" || skewedJoinType.get == "skewedbroadcast"), @@ -43,6 +43,6 @@ case class JoinArgs(leftFrame: JoinFrameArgs, * Join arguments for frame * * @param frame Data frame - * @param joinColumn Join column name + * @param joinColumns Join column name */ -case class JoinFrameArgs(frame: FrameReference, joinColumn: String) +case class JoinFrameArgs(frame: FrameReference, joinColumns: List[String]) diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinPlugin.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinPlugin.scala index 8feb573fec..99978961ec 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinPlugin.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinPlugin.scala @@ -69,19 +69,21 @@ class JoinPlugin extends SparkCommandPlugin[JoinArgs, FrameReference] { val rightFrame: SparkFrame = arguments.rightFrame.frame //first validate join columns are valid - leftFrame.schema.validateColumnsExist(List(arguments.leftFrame.joinColumn)) - rightFrame.schema.validateColumnsExist(List(arguments.rightFrame.joinColumn)) - require(DataTypes.isCompatibleDataType( - leftFrame.schema.columnDataType(arguments.leftFrame.joinColumn), - rightFrame.schema.columnDataType(arguments.rightFrame.joinColumn)), - "Join columns must have compatible data types") + leftFrame.schema.validateColumnsExist(arguments.leftFrame.joinColumns) + rightFrame.schema.validateColumnsExist(arguments.rightFrame.joinColumns) + + //Check left join column is compatiable with right join column + (arguments.leftFrame.joinColumns zip arguments.rightFrame.joinColumns).map{ case(leftJoinCol, rightJoinCol) => require(DataTypes.isCompatibleDataType( + leftFrame.schema.columnDataType(leftJoinCol), + rightFrame.schema.columnDataType(rightJoinCol)), + "Join columns must have compatible data types")} // Get estimated size of frame to 
determine whether to use a broadcast join val broadcastJoinThreshold = EngineConfig.broadcastJoinThreshold val joinedFrame = JoinRddFunctions.join( - createRDDJoinParam(leftFrame, arguments.leftFrame.joinColumn, broadcastJoinThreshold), - createRDDJoinParam(rightFrame, arguments.rightFrame.joinColumn, broadcastJoinThreshold), + createRDDJoinParam(leftFrame, arguments.leftFrame.joinColumns, broadcastJoinThreshold), + createRDDJoinParam(rightFrame, arguments.rightFrame.joinColumns, broadcastJoinThreshold), arguments.how, broadcastJoinThreshold, arguments.skewedJoinType @@ -95,14 +97,14 @@ class JoinPlugin extends SparkCommandPlugin[JoinArgs, FrameReference] { //Create parameters for join private def createRDDJoinParam(frame: SparkFrame, - joinColumn: String, + joinColumns: Seq[String], broadcastJoinThreshold: Long): RddJoinParam = { val frameSize = if (broadcastJoinThreshold > 0) frame.sizeInBytes else None val estimatedRddSize = frameSize match { case Some(size) => Some((size * EngineConfig.frameCompressionRatio).toLong) case _ => None } - RddJoinParam(frame.rdd, joinColumn, estimatedRddSize) + RddJoinParam(frame.rdd, joinColumns, estimatedRddSize) } } diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala index 5d980b83dd..309780131b 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala @@ -88,7 +88,7 @@ object JoinRddFunctions extends Serializable { val rightFrame = right.frame.toDataFrame val joinedFrame = leftFrame.join( rightFrame, - leftFrame(left.joinColumn).equalTo(rightFrame(right.joinColumn)) + left.joinColumns ) joinedFrame.rdd } @@ -108,13 +108,26 @@ object JoinRddFunctions extends Serializable { def fullOuterJoin(left: RddJoinParam, right: RddJoinParam): RDD[Row] = { val leftFrame = left.frame.toDataFrame val rightFrame = right.frame.toDataFrame + + val columnsTuple=left.joinColumns.zip(right.joinColumns) + + var exps= makeExpression(columnsTuple.head._1, columnsTuple.head._2) + + columnsTuple.tail.map{ case(lc, rc) => exps= exps && makeExpression(lc, rc)} + + def makeExpression(leftCol: String, rightCol:String): Column ={ + leftFrame(leftCol).equalTo(rightFrame(rightCol)) + } + val joinedFrame = leftFrame.join(rightFrame, - leftFrame(left.joinColumn).equalTo(rightFrame(right.joinColumn)), + exps, joinType = "fullouter" ) joinedFrame.rdd } + + /** * Perform right-outer join * @@ -138,8 +151,19 @@ object JoinRddFunctions extends Serializable { case _ => val leftFrame = left.frame.toDataFrame val rightFrame = right.frame.toDataFrame + + val columnsTuple=left.joinColumns.zip(right.joinColumns) + + var exps= makeExpression(columnsTuple.head._1, columnsTuple.head._2) + + columnsTuple.tail.map{ case(lc, rc) => exps= exps && makeExpression(lc, rc)} + + def makeExpression(leftCol: String, rightCol:String): Column ={ + leftFrame(leftCol).equalTo(rightFrame(rightCol)) + } + val joinedFrame = leftFrame.join(rightFrame, - leftFrame(left.joinColumn).equalTo(rightFrame(right.joinColumn)), + exps, joinType = "right" ) joinedFrame.rdd @@ -167,8 +191,19 @@ object JoinRddFunctions extends Serializable { case _ => val leftFrame = left.frame.toDataFrame val rightFrame = right.frame.toDataFrame + + val 
columnsTuple=left.joinColumns.zip(right.joinColumns) + + var exps= makeExpression(columnsTuple.head._1, columnsTuple.head._2) + + columnsTuple.tail.map{ case(lc, rc) => exps= exps && makeExpression(lc, rc)} + + def makeExpression(leftCol: String, rightCol:String): Column ={ + leftFrame(leftCol).equalTo(rightFrame(rightCol)) + } + val joinedFrame = leftFrame.join(rightFrame, - leftFrame(left.joinColumn).equalTo(rightFrame(right.joinColumn)), + exps, joinType = "left" ) joinedFrame.rdd diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/RddJoinParam.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/RddJoinParam.scala index 325a522b40..20673e41d4 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/RddJoinParam.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/RddJoinParam.scala @@ -22,14 +22,14 @@ import org.apache.spark.frame.FrameRdd * Join parameters for RDD * * @param frame Frame used for join - * @param joinColumn Join column name + * @param joinColumns Join column name * @param estimatedSizeInBytes Optional estimated size of RDD in bytes used to determine whether to use a broadcast join */ case class RddJoinParam(frame: FrameRdd, - joinColumn: String, + joinColumns: Seq[String], estimatedSizeInBytes: Option[Long] = None) { require(frame != null, "join frame is required") - require(joinColumn != null, "join column is required") + require(joinColumns != null, "join column(s) are required") require(estimatedSizeInBytes.isEmpty || estimatedSizeInBytes.get > 0, "Estimated rdd size in bytes should be empty or greater than zero") } diff --git a/python-client/trustedanalytics/core/frame.py b/python-client/trustedanalytics/core/frame.py index 6ea4d7effc..58f1194546 100644 --- a/python-client/trustedanalytics/core/frame.py +++ b/python-client/trustedanalytics/core/frame.py @@ -1075,8 +1075,8 @@ def __inspect(self, @api @beta @arg('right', 'Frame', "Another frame to join with") - @arg('left_on', str, "Name of the column in the left frame used to match up the two frames.") - @arg('right_on', str, "Name of the column in the right frame used to match up the two frames. " + @arg('left_on', list, "Names of the columns in the left frame used to match up the two frames.") + @arg('right_on', list, "Names of the columns in the right frame used to match up the two frames. " "Default is the same as the left frame.") @arg('how', str, "How to qualify the data to be joined together. Must be one of the following: " "'left', 'right', 'inner', 'outer'. 
Default is 'inner'") diff --git a/python-client/trustedanalytics/rest/frame.py b/python-client/trustedanalytics/rest/frame.py index 6f3f25d7e8..a9f4d2f84b 100644 --- a/python-client/trustedanalytics/rest/frame.py +++ b/python-client/trustedanalytics/rest/frame.py @@ -474,8 +474,8 @@ def join(self, left, right, left_on, right_on, how, name=None): right_on = left_on arguments = {"name": name, "how": how, - "left_frame": {"frame": left.uri, "join_column": left_on}, - "right_frame": {"frame": right.uri, "join_column": right_on} } + "left_frame": {"frame": left.uri, "join_columns": left_on}, + "right_frame": {"frame": right.uri, "join_columns": right_on} } return execute_new_frame_command('frame:/join', arguments) def copy(self, frame, columns=None, where=None, name=None): From 5ee7b2833dc999cb8cd00997f598e9519323c173 Mon Sep 17 00:00:00 2001 From: Karthik Vadla Date: Mon, 4 Apr 2016 17:36:46 -0700 Subject: [PATCH 06/15] Added changes to sparkjointest cases and modified all join related files to support composite keys --- .../join/BroadcastJoinRddFunctions.scala | 47 +++-- .../engine/frame/plugins/join/JoinArgs.scala | 2 +- .../plugins/join/JoinBroadcastVariable.scala | 4 +- .../frame/plugins/join/JoinPlugin.scala | 10 +- .../frame/plugins/join/JoinRddFunctions.scala | 68 +++--- .../join/JoinBroadcastVariableITest.scala | 8 +- .../frame/plugins/join/SparkJoinITest.scala | 194 +++++++++--------- 7 files changed, 163 insertions(+), 170 deletions(-) diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala index 6982684c61..ca7ff83b34 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala @@ -20,6 +20,7 @@ import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.trustedanalytics.atk.engine.frame.RowWrapper /** * Functions for joining pair RDDs using broadcast variables @@ -30,19 +31,22 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali * Perform left outer-join using a broadcast variable * * @param other join parameter for second data frame - * - * @return key-value RDD whose values are results of left-outer join + * @return key-value RDD whose values are results of left-outer join */ def leftBroadcastJoin(other: RddJoinParam): RDD[Row] = { val rightBroadcastVariable = JoinBroadcastVariable(other) lazy val rightNullRow: Row = new GenericRow(other.frame.numColumns) + val rowWrapper = new RowWrapper(other.frame.frameSchema) + val rightColsToKeep = other.frame.frameSchema.dropColumns(other.joinColumns.toList).columnNames self.frame.flatMapRows(left => { - val leftKey = left.value(self.joinColumn) - rightBroadcastVariable.get(leftKey) match { - case Some(rightRowSet) => for (rightRow <- rightRowSet) yield Row.merge(left.row, rightRow) + + val leftKeys = left.values(self.joinColumns.toList) + rightBroadcastVariable.get(leftKeys) match { + case Some(rightRowSet) => for (rightRow <- rightRowSet) yield Row.merge(left.row, new GenericRow(rowWrapper(rightRow).values(rightColsToKeep).toArray)) case _ => List(Row.merge(left.row, 
rightNullRow.copy())) } + }) } @@ -50,17 +54,19 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali * Right outer-join using a broadcast variable * * @param other join parameter for second data frame - * - * @return key-value RDD whose values are results of right-outer join + * @return key-value RDD whose values are results of right-outer join */ def rightBroadcastJoin(other: RddJoinParam): RDD[Row] = { val leftBroadcastVariable = JoinBroadcastVariable(self) lazy val leftNullRow: Row = new GenericRow(self.frame.numColumns) + val rowWrapper = new RowWrapper(other.frame.frameSchema) other.frame.flatMapRows(right => { - val rightKey = right.value(other.joinColumn) - leftBroadcastVariable.get(rightKey) match { - case Some(leftRowSet) => for (leftRow <- leftRowSet) yield Row.merge(leftRow, right.row) + // val rightKey = right.value(other.joinColumn) + val leftColsToKeep = self.frame.frameSchema.dropColumns(self.joinColumns.toList).columnNames + val rightKeys = right.values(other.joinColumns.toList) + leftBroadcastVariable.get(rightKeys) match { + case Some(leftRowSet) => for (leftRow <- leftRowSet) yield Row.merge(new GenericRow(rowWrapper(leftRow).values(leftColsToKeep).toArray), right.row) case _ => List(Row.merge(leftNullRow.copy(), right.row)) } }) @@ -70,20 +76,23 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali * Inner-join using a broadcast variable * * @param other join parameter for second data frame - * - * @return key-value RDD whose values are results of inner-outer join + * @return key-value RDD whose values are results of inner-outer join */ def innerBroadcastJoin(other: RddJoinParam, broadcastJoinThreshold: Long): RDD[Row] = { val leftSizeInBytes = self.estimatedSizeInBytes.getOrElse(Long.MaxValue) val rightSizeInBytes = other.estimatedSizeInBytes.getOrElse(Long.MaxValue) + val rowWrapper = new RowWrapper(other.frame.frameSchema) val innerJoinedRDD = if (rightSizeInBytes <= broadcastJoinThreshold) { val rightBroadcastVariable = JoinBroadcastVariable(other) + + val rightColsToKeep = other.frame.frameSchema.dropColumns(other.joinColumns.toList).columnNames self.frame.flatMapRows(left => { - val leftKey = left.value(self.joinColumn) - rightBroadcastVariable.get(leftKey) match { + // val leftKey = left.value(self.joinColumn) + val leftKeys = left.values(self.joinColumns.toList) + rightBroadcastVariable.get(leftKeys) match { case Some(rightRowSet) => - for (rightRow <- rightRowSet) yield Row.merge(left.row, rightRow) + for (rightRow <- rightRowSet) yield Row.merge(left.row, new GenericRow(rowWrapper(rightRow).values(rightColsToKeep).toArray)) case _ => Set.empty[Row] } }) @@ -91,10 +100,12 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali else if (leftSizeInBytes <= broadcastJoinThreshold) { val leftBroadcastVariable = JoinBroadcastVariable(self) other.frame.flatMapRows(rightRow => { - val rightKey = rightRow.value(other.joinColumn) - leftBroadcastVariable.get(rightKey) match { + // val rightKey = rightRow.value(other.joinColumn) + val leftColsToKeep = self.frame.frameSchema.dropColumns(self.joinColumns.toList).columnNames + val rightKeys = rightRow.values(other.joinColumns.toList) + leftBroadcastVariable.get(rightKeys) match { case Some(leftRowSet) => - for (leftRow <- leftRowSet) yield Row.merge(leftRow, rightRow.row) + for (leftRow <- leftRowSet) yield Row.merge(new GenericRow(rowWrapper(leftRow).values(leftColsToKeep).toArray), rightRow.row) case _ => Set.empty[Row] } }) diff --git 
a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinArgs.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinArgs.scala index 611772216c..81ec4a5ac4 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinArgs.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinArgs.scala @@ -24,7 +24,7 @@ import org.trustedanalytics.atk.engine.plugin.{ ArgDoc, Invocation } * Arguments for Join plugin * */ -case class JoinArgs(leftFrame: JoinFrameArgs, +case class JoinArgs(leftFrame: JoinFrameArgs, @ArgDoc("""Join arguments for first data frame.""") rightFrame: JoinFrameArgs, @ArgDoc("""Methods of join (inner, left, right or outer).""") how: String, @ArgDoc("""Name of new frame to be created.""") name: Option[String] = None, diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariable.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariable.scala index b88606c6a8..89549c0626 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariable.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariable.scala @@ -52,7 +52,9 @@ case class JoinBroadcastVariable(joinParam: RddJoinParam) { // Create the broadcast variable for the join private def createBroadcastMultiMaps(joinParam: RddJoinParam): Broadcast[MultiMap[Any, Row]] = { //Grouping by key to ensure that duplicate keys are not split across different broadcast variables - val broadcastList = joinParam.frame.groupByRows(row => row.value(joinParam.joinColumn)).collect().toList + //val broadcastList = joinParam.frame.groupByRows(row => row.value(joinParam.joinColumn)).collect().toList + + val broadcastList = joinParam.frame.groupByRows(row => row.values(joinParam.joinColumns.toList)).collect().toList val broadcastMultiMap = listToMultiMap(broadcastList) joinParam.frame.sparkContext.broadcast(broadcastMultiMap) diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinPlugin.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinPlugin.scala index 99978961ec..df24e2d02a 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinPlugin.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinPlugin.scala @@ -73,10 +73,12 @@ class JoinPlugin extends SparkCommandPlugin[JoinArgs, FrameReference] { rightFrame.schema.validateColumnsExist(arguments.rightFrame.joinColumns) //Check left join column is compatiable with right join column - (arguments.leftFrame.joinColumns zip arguments.rightFrame.joinColumns).map{ case(leftJoinCol, rightJoinCol) => require(DataTypes.isCompatibleDataType( - leftFrame.schema.columnDataType(leftJoinCol), - rightFrame.schema.columnDataType(rightJoinCol)), - "Join columns must have compatible data types")} + (arguments.leftFrame.joinColumns zip arguments.rightFrame.joinColumns).map { + case (leftJoinCol, rightJoinCol) => require(DataTypes.isCompatibleDataType( + leftFrame.schema.columnDataType(leftJoinCol), + 
rightFrame.schema.columnDataType(rightJoinCol)), + "Join columns must have compatible data types") + } // Get estimated size of frame to determine whether to use a broadcast join val broadcastJoinThreshold = EngineConfig.broadcastJoinThreshold diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala index 309780131b..69eb122196 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala @@ -69,7 +69,6 @@ object JoinRddFunctions extends Serializable { * @param left join parameter for first data frame * @param right join parameter for second data frame * @param broadcastJoinThreshold use broadcast variable for join if size of one of the data frames is below threshold - * * @return Joined RDD */ def innerJoin(left: RddJoinParam, @@ -102,22 +101,19 @@ object JoinRddFunctions extends Serializable { * * @param left join parameter for first data frame * @param right join parameter for second data frame - * * @return Joined RDD */ def fullOuterJoin(left: RddJoinParam, right: RddJoinParam): RDD[Row] = { val leftFrame = left.frame.toDataFrame val rightFrame = right.frame.toDataFrame - val columnsTuple=left.joinColumns.zip(right.joinColumns) - - var exps= makeExpression(columnsTuple.head._1, columnsTuple.head._2) + val columnsTuple = left.joinColumns.zip(right.joinColumns) - columnsTuple.tail.map{ case(lc, rc) => exps= exps && makeExpression(lc, rc)} - - def makeExpression(leftCol: String, rightCol:String): Column ={ + def makeExpression(leftCol: String, rightCol: String): Column = { leftFrame(leftCol).equalTo(rightFrame(rightCol)) } + var exps = makeExpression(columnsTuple.head._1, columnsTuple.head._2) + columnsTuple.tail.map { case (lc, rc) => exps = exps && makeExpression(lc, rc) } val joinedFrame = leftFrame.join(rightFrame, exps, @@ -126,8 +122,6 @@ object JoinRddFunctions extends Serializable { joinedFrame.rdd } - - /** * Perform right-outer join * @@ -136,7 +130,6 @@ object JoinRddFunctions extends Serializable { * @param left join parameter for first data frame * @param right join parameter for second data frame * @param broadcastJoinThreshold use broadcast variable for join if size of first data frame is below threshold - * * @return Joined RDD */ def rightOuterJoin(left: RddJoinParam, @@ -152,15 +145,13 @@ object JoinRddFunctions extends Serializable { val leftFrame = left.frame.toDataFrame val rightFrame = right.frame.toDataFrame - val columnsTuple=left.joinColumns.zip(right.joinColumns) - - var exps= makeExpression(columnsTuple.head._1, columnsTuple.head._2) + val columnsTuple = left.joinColumns.zip(right.joinColumns) - columnsTuple.tail.map{ case(lc, rc) => exps= exps && makeExpression(lc, rc)} - - def makeExpression(leftCol: String, rightCol:String): Column ={ + def makeExpression(leftCol: String, rightCol: String): Column = { leftFrame(leftCol).equalTo(rightFrame(rightCol)) } + var exps = makeExpression(columnsTuple.head._1, columnsTuple.head._2) + columnsTuple.tail.map { case (lc, rc) => exps = exps && makeExpression(lc, rc) } val joinedFrame = leftFrame.join(rightFrame, exps, @@ -178,7 +169,6 @@ object JoinRddFunctions extends Serializable { * @param left join parameter for first data frame * @param 
right join parameter for second data frame * @param broadcastJoinThreshold use broadcast variable for join if size of second data frame is below threshold - * * @return Joined RDD */ def leftOuterJoin(left: RddJoinParam, @@ -192,15 +182,13 @@ object JoinRddFunctions extends Serializable { val leftFrame = left.frame.toDataFrame val rightFrame = right.frame.toDataFrame - val columnsTuple=left.joinColumns.zip(right.joinColumns) + val columnsTuple = left.joinColumns.zip(right.joinColumns) - var exps= makeExpression(columnsTuple.head._1, columnsTuple.head._2) - - columnsTuple.tail.map{ case(lc, rc) => exps= exps && makeExpression(lc, rc)} - - def makeExpression(leftCol: String, rightCol:String): Column ={ + def makeExpression(leftCol: String, rightCol: String): Column = { leftFrame(leftCol).equalTo(rightFrame(rightCol)) } + var exps = makeExpression(columnsTuple.head._1, columnsTuple.head._2) + columnsTuple.tail.map { case (lc, rc) => exps = exps && makeExpression(lc, rc) } val joinedFrame = leftFrame.join(rightFrame, exps, @@ -252,19 +240,26 @@ object JoinRddFunctions extends Serializable { def mergeJoinColumns(joinedRdd: RDD[Row], left: RddJoinParam, right: RddJoinParam): RDD[Row] = { + val leftSchema = left.frame.frameSchema val rightSchema = right.frame.frameSchema - val leftJoinIndex = leftSchema.columnIndex(left.joinColumn) - val rightJoinIndex = rightSchema.columnIndex(right.joinColumn) + leftSchema.columns.size + // val leftJoinIndex = leftSchema.columnIndex(left.joinColumn) + // val rightJoinIndex = rightSchema.columnIndex(right.joinColumn) + leftSchema.columns.size + + val leftJoinIndices = leftSchema.columnIndices(left.joinColumns) + val rightJoinIndices = rightSchema.columnIndices(right.joinColumns).map(rightindex => rightindex + leftSchema.columns.size) joinedRdd.map(row => { - val leftKey = row.get(leftJoinIndex) - val rightKey = row.get(rightJoinIndex) - val newLeftKey = if (leftKey == null) rightKey else leftKey val rowArray = row.toSeq.toArray - rowArray(leftJoinIndex) = newLeftKey + leftJoinIndices.zip(rightJoinIndices).map { + case (leftindex, rightindex) => { + if (row.get(leftindex) == null) { + rowArray(leftindex) = row.get(rightindex) + } + } + } new GenericRow(rowArray) }) } @@ -285,12 +280,8 @@ object JoinRddFunctions extends Serializable { val leftSchema = left.frame.frameSchema val rightSchema = right.frame.frameSchema - // Get unique name for left join column to drop - val oldSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns)) - val dropColumnName = oldSchema.column(leftSchema.columnIndex(left.joinColumn)).name - // Create new schema - val newLeftSchema = leftSchema.renameColumn(left.joinColumn, dropColumnName) + val newLeftSchema = leftSchema.dropColumns(left.joinColumns.toList) val newSchema = FrameSchema(Schema.join(newLeftSchema.columns, rightSchema.columns)) new FrameRdd(newSchema, joinedRdd) @@ -309,15 +300,12 @@ object JoinRddFunctions extends Serializable { def dropRightJoinColumn(joinedRdd: RDD[Row], left: RddJoinParam, right: RddJoinParam): FrameRdd = { + val leftSchema = left.frame.frameSchema val rightSchema = right.frame.frameSchema - // Get unique name for right join column to drop - val oldSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns)) - val dropColumnName = oldSchema.column(leftSchema.columns.size + rightSchema.columnIndex(right.joinColumn)).name - // Create new schema - val newRightSchema = rightSchema.renameColumn(right.joinColumn, dropColumnName) + val newRightSchema = 
rightSchema.dropColumns(right.joinColumns.toList) val newSchema = FrameSchema(Schema.join(leftSchema.columns, newRightSchema.columns)) new FrameRdd(newSchema, joinedRdd) diff --git a/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariableITest.scala b/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariableITest.scala index 09ac8614d8..95b7a0a95d 100644 --- a/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariableITest.scala +++ b/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariableITest.scala @@ -41,7 +41,7 @@ class JoinBroadcastVariableITest extends TestingSparkContextFlatSpec with Matche "JoinBroadcastVariable" should "create a single broadcast variable when RDD size is less than 2GB" in { val countryNames = new FrameRdd(inputSchema, sparkContext.parallelize(idCountryNames)) - val joinParam = RddJoinParam(countryNames, "col_0", Some(150)) + val joinParam = RddJoinParam(countryNames, Seq("col_0"), Some(150)) val broadcastVariable = JoinBroadcastVariable(joinParam) @@ -57,7 +57,7 @@ class JoinBroadcastVariableITest extends TestingSparkContextFlatSpec with Matche "JoinBroadcastVariable" should "create a two broadcast variables when RDD size is equals 3GB" in { val countryNames = new FrameRdd(inputSchema, sparkContext.parallelize(idCountryNames)) - val joinParam = RddJoinParam(countryNames, "col_0", Some(3L * 1024 * 1024 * 1024)) + val joinParam = RddJoinParam(countryNames, Seq("col_0"), Some(3L * 1024 * 1024 * 1024)) val broadcastVariable = JoinBroadcastVariable(joinParam) @@ -72,7 +72,7 @@ class JoinBroadcastVariableITest extends TestingSparkContextFlatSpec with Matche "JoinBroadcastVariable" should "create an empty broadcast variable" in { val countryNames = new FrameRdd(inputSchema, sparkContext.parallelize(List.empty[Row])) - val joinParam = RddJoinParam(countryNames, "col_0", Some(3L * 1024 * 1024 * 1024)) + val joinParam = RddJoinParam(countryNames, Seq("col_0"), Some(3L * 1024 * 1024 * 1024)) val broadcastVariable = JoinBroadcastVariable(joinParam) @@ -86,7 +86,7 @@ class JoinBroadcastVariableITest extends TestingSparkContextFlatSpec with Matche "JoinBroadcastVariable" should "throw an Exception if column does not exist in frame" in { intercept[Exception] { val countryNames = new FrameRdd(inputSchema, sparkContext.parallelize(idCountryNames)) - val joinParam = RddJoinParam(countryNames, "col_bad", Some(3L * 1024 * 1024 * 1024)) + val joinParam = RddJoinParam(countryNames, Seq("col_bad"), Some(3L * 1024 * 1024 * 1024)) JoinBroadcastVariable(joinParam) } } diff --git a/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/SparkJoinITest.scala b/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/SparkJoinITest.scala index cdcd78fe97..bfd0dfb845 100644 --- a/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/SparkJoinITest.scala +++ b/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/SparkJoinITest.scala @@ -57,24 +57,23 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) - val 
resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, "col_0"), - RddJoinParam(countryNames, "col_0"), "inner") + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0")), + RddJoinParam(countryNames, Seq("col_0")), "inner") val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( Column("col_0", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), - Column("col_1_R", DataTypes.str, 3) + Column("col_1_R", DataTypes.str, 2) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, "Oman")) + new GenericRow(Array[Any](1, 354, "Iceland")), + new GenericRow(Array[Any](1, 354, "Ice-land")), + new GenericRow(Array[Any](2, 91, "India")), + new GenericRow(Array[Any](2, 100, "India")), + new GenericRow(Array[Any](3, 47, "Norway")), + new GenericRow(Array[Any](4, 968, "Oman")) ) results should contain theSameElementsAs expectedResults @@ -84,8 +83,8 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) - val leftJoinParam = RddJoinParam(countryCode, "col_0", Some(150)) - val rightJoinParam = RddJoinParam(countryNames, "col_0", Some(10000)) + val leftJoinParam = RddJoinParam(countryCode, Seq("col_0"), Some(150)) + val rightJoinParam = RddJoinParam(countryNames, Seq("col_0"), Some(10000)) val resultFrame = JoinRddFunctions.join(leftJoinParam, rightJoinParam, "inner") val results = resultFrame.collect() @@ -93,17 +92,16 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { resultFrame.frameSchema.columns should equal(List( Column("col_0", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), - Column("col_1_R", DataTypes.str, 3) + Column("col_1_R", DataTypes.str, 2) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, "Oman")) + new GenericRow(Array[Any](1, 354, "Iceland")), + new GenericRow(Array[Any](1, 354, "Ice-land")), + new GenericRow(Array[Any](2, 91, "India")), + new GenericRow(Array[Any](2, 100, "India")), + new GenericRow(Array[Any](3, 47, "Norway")), + new GenericRow(Array[Any](4, 968, "Oman")) ) results should contain theSameElementsAs expectedResults @@ -113,25 +111,24 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) - val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, "col_0"), - RddJoinParam(countryNames, "col_0"), "left") + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0")), + RddJoinParam(countryNames, Seq("col_0")), "left") val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( Column("col_0", DataTypes.int32, 0), 
Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), - Column("col_1_R", DataTypes.str, 3) + Column("col_1_R", DataTypes.str, 2) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, "Oman")), - new GenericRow(Array[Any](5, 50, null, null)) + new GenericRow(Array[Any](1, 354,"Iceland")), + new GenericRow(Array[Any](1, 354,"Ice-land")), + new GenericRow(Array[Any](2, 91,"India")), + new GenericRow(Array[Any](2, 100,"India")), + new GenericRow(Array[Any](3, 47,"Norway")), + new GenericRow(Array[Any](4, 968,"Oman")), + new GenericRow(Array[Any](5, 50, null)) ) results should contain theSameElementsAs expectedResults @@ -141,8 +138,8 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) - val leftJoinParam = RddJoinParam(countryCode, "col_0", Some(1500L)) - val rightJoinParam = RddJoinParam(countryNames, "col_0", Some(100L + Int.MaxValue)) + val leftJoinParam = RddJoinParam(countryCode, Seq("col_0"), Some(1500L)) + val rightJoinParam = RddJoinParam(countryNames, Seq("col_0"), Some(100L + Int.MaxValue)) // Test join wrapper function val resultFrame = JoinRddFunctions.join(leftJoinParam, rightJoinParam, "left") @@ -151,18 +148,17 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { resultFrame.frameSchema.columns should equal(List( Column("col_0", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), - Column("col_1_R", DataTypes.str, 3) + Column("col_1_R", DataTypes.str, 2) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, "Oman")), - new GenericRow(Array[Any](5, 50, null, null)) + new GenericRow(Array[Any](1, 354,"Iceland")), + new GenericRow(Array[Any](1, 354, "Ice-land")), + new GenericRow(Array[Any](2, 91, "India")), + new GenericRow(Array[Any](2, 100, "India")), + new GenericRow(Array[Any](3, 47, "Norway")), + new GenericRow(Array[Any](4, 968, "Oman")), + new GenericRow(Array[Any](5, 50, null)) ) results should contain theSameElementsAs expectedResults @@ -172,25 +168,24 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) val resultFrame = JoinRddFunctions.join( - RddJoinParam(countryCode, "col_0"), - RddJoinParam(countryNames, "col_0"), "right") + RddJoinParam(countryCode, Seq("col_0")), + RddJoinParam(countryNames, Seq("col_0")), "right") val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( - Column("col_0_L", DataTypes.int32, 0), - Column("col_1_L", DataTypes.int32, 1), - Column("col_0", DataTypes.int32, 2), - Column("col_1_R", DataTypes.str, 3) + Column("col_1_L", DataTypes.int32, 0), + Column("col_0", DataTypes.int32, 1), + Column("col_1_R", DataTypes.str, 2) )) val expectedResults = List( - new GenericRow(Array[Any](1, 
354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, "Oman")), - new GenericRow(Array[Any](null, null, 6, "Germany")) + new GenericRow(Array[Any](354, 1, "Iceland")), + new GenericRow(Array[Any](354, 1, "Ice-land")), + new GenericRow(Array[Any](91, 2, "India")), + new GenericRow(Array[Any](100, 2, "India")), + new GenericRow(Array[Any](47, 3, "Norway")), + new GenericRow(Array[Any](968, 4, "Oman")), + new GenericRow(Array[Any](null, 6, "Germany")) ) results should contain theSameElementsAs expectedResults @@ -201,27 +196,26 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) val broadcastJoinThreshold = 1000 - val leftJoinParam = RddJoinParam(countryCode, "col_0", Some(800)) - val rightJoinParam = RddJoinParam(countryNames, "col_0", Some(4000)) + val leftJoinParam = RddJoinParam(countryCode, Seq("col_0"), Some(800)) + val rightJoinParam = RddJoinParam(countryNames, Seq("col_0"), Some(4000)) val resultFrame = JoinRddFunctions.join(leftJoinParam, rightJoinParam, "right", broadcastJoinThreshold) val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( - Column("col_0_L", DataTypes.int32, 0), - Column("col_1_L", DataTypes.int32, 1), - Column("col_0", DataTypes.int32, 2), - Column("col_1_R", DataTypes.str, 3) + Column("col_1_L", DataTypes.int32, 0), + Column("col_0", DataTypes.int32, 1), + Column("col_1_R", DataTypes.str, 2) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, "Oman")), - new GenericRow(Array[Any](null, null, 6, "Germany")) + new GenericRow(Array[Any](354, 1, "Iceland")), + new GenericRow(Array[Any](354, 1, "Ice-land")), + new GenericRow(Array[Any](91, 2, "India")), + new GenericRow(Array[Any](100, 2, "India")), + new GenericRow(Array[Any](47, 3, "Norway")), + new GenericRow(Array[Any](968, 4, "Oman")), + new GenericRow(Array[Any](null, 6, "Germany")) ) results should contain theSameElementsAs expectedResults @@ -231,26 +225,25 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) - val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, "col_0"), - RddJoinParam(countryNames, "col_0"), "outer") + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0")), + RddJoinParam(countryNames, Seq("col_0")), "outer") val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( Column("col_0", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), - Column("col_1_R", DataTypes.str, 3) + Column("col_1_R", DataTypes.str, 2) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, 1, "Iceland")), - new GenericRow(Array[Any](1, 354, 1, "Ice-land")), - new GenericRow(Array[Any](2, 91, 2, "India")), - new GenericRow(Array[Any](2, 100, 2, "India")), - new GenericRow(Array[Any](3, 
47, 3, "Norway")), - new GenericRow(Array[Any](4, 968, 4, "Oman")), - new GenericRow(Array[Any](5, 50, null, null)), - new GenericRow(Array[Any](6, null, 6, "Germany")) + new GenericRow(Array[Any](1, 354,"Iceland")), + new GenericRow(Array[Any](1, 354,"Ice-land")), + new GenericRow(Array[Any](2, 91,"India")), + new GenericRow(Array[Any](2, 100,"India")), + new GenericRow(Array[Any](3, 47,"Norway")), + new GenericRow(Array[Any](4, 968, "Oman")), + new GenericRow(Array[Any](5, 50, null)), + new GenericRow(Array[Any](6, null,"Germany")) ) results should contain theSameElementsAs expectedResults @@ -261,24 +254,22 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(emptyIdCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) - val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, "col_0"), - RddJoinParam(countryNames, "col_0"), "outer") + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0")), RddJoinParam(countryNames, Seq("col_0")), "outer") val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( Column("col_0", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), - Column("col_1_R", DataTypes.str, 3) + Column("col_1_R", DataTypes.str, 2) )) val expectedResults = List( - new GenericRow(Array[Any](1, null, 1, "Iceland")), - new GenericRow(Array[Any](1, null, 1, "Ice-land")), - new GenericRow(Array[Any](2, null, 2, "India")), - new GenericRow(Array[Any](3, null, 3, "Norway")), - new GenericRow(Array[Any](4, null, 4, "Oman")), - new GenericRow(Array[Any](6, null, 6, "Germany")) + new GenericRow(Array[Any](1, null,"Iceland")), + new GenericRow(Array[Any](1, null,"Ice-land")), + new GenericRow(Array[Any](2, null,"India")), + new GenericRow(Array[Any](3, null, "Norway")), + new GenericRow(Array[Any](4, null,"Oman")), + new GenericRow(Array[Any](6, null,"Germany")) ) results should contain theSameElementsAs expectedResults @@ -289,24 +280,23 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(emptyIdCountryNames)) - val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, "col_0"), - RddJoinParam(countryNames, "col_0"), "outer") + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0")), + RddJoinParam(countryNames, Seq("col_0")), "outer") val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( Column("col_0", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_0_R", DataTypes.int32, 2), - Column("col_1_R", DataTypes.str, 3) + Column("col_1_R", DataTypes.str, 2) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, null, null)), - new GenericRow(Array[Any](2, 91, null, null)), - new GenericRow(Array[Any](2, 100, null, null)), - new GenericRow(Array[Any](3, 47, null, null)), - new GenericRow(Array[Any](4, 968, null, null)), - new GenericRow(Array[Any](5, 50, null, null)) + new GenericRow(Array[Any](1, 354, null)), + new GenericRow(Array[Any](2, 91, null)), + new GenericRow(Array[Any](2, 100, null)), + new GenericRow(Array[Any](3, 47, null)), + new GenericRow(Array[Any](4, 968, null)), + new GenericRow(Array[Any](5, 50, null)) ) results should contain theSameElementsAs 
expectedResults @@ -327,8 +317,8 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val rddFiveHundredThousandsToOneFiftyThousands = new FrameRdd(inputSchema, sparkContext.parallelize(fiftyThousandToOneFiftyThousands)) - val resultFrame = JoinRddFunctions.join(RddJoinParam(rddOneToMillion, "col_0"), - RddJoinParam(rddFiveHundredThousandsToOneFiftyThousands, "col_0"), "outer") + val resultFrame = JoinRddFunctions.join(RddJoinParam(rddOneToMillion, Seq("col_0")), + RddJoinParam(rddFiveHundredThousandsToOneFiftyThousands, Seq("col_0")), "outer") resultFrame.frameSchema.columns should equal(List( Column("col_0", DataTypes.int32, 0), From 7c20ce8cd9f1fc113477315a626627aacecffe59 Mon Sep 17 00:00:00 2001 From: Karthik Vadla Date: Wed, 6 Apr 2016 18:08:36 -0700 Subject: [PATCH 07/15] Join plugin with composite key support added --- conf/.gitignore | 4 +- .../join/BroadcastJoinRddFunctions.scala | 20 +- .../plugins/join/JoinBroadcastVariable.scala | 3 - .../frame/plugins/join/JoinRddFunctions.scala | 100 ++-- .../join/JoinBroadcastVariableITest.scala | 25 +- .../frame/plugins/join/SparkJoinITest.scala | 432 ++++++++++++++---- python-client/trustedanalytics/core/frame.py | 127 +++-- 7 files changed, 492 insertions(+), 219 deletions(-) diff --git a/conf/.gitignore b/conf/.gitignore index 9475c0774e..9fac4bc444 100644 --- a/conf/.gitignore +++ b/conf/.gitignore @@ -1,3 +1,5 @@ application.conf dev.conf -logback.xml \ No newline at end of file +logback.xml +application.conf.xavier3 +generated.conf.xavier3 \ No newline at end of file diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala index ca7ff83b34..0ba4ecd592 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala @@ -31,19 +31,15 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali * Perform left outer-join using a broadcast variable * * @param other join parameter for second data frame - * @return key-value RDD whose values are results of left-outer join + * @return key-value RDD whose values are results of left-outer join */ def leftBroadcastJoin(other: RddJoinParam): RDD[Row] = { val rightBroadcastVariable = JoinBroadcastVariable(other) lazy val rightNullRow: Row = new GenericRow(other.frame.numColumns) - - val rowWrapper = new RowWrapper(other.frame.frameSchema) - val rightColsToKeep = other.frame.frameSchema.dropColumns(other.joinColumns.toList).columnNames self.frame.flatMapRows(left => { - val leftKeys = left.values(self.joinColumns.toList) rightBroadcastVariable.get(leftKeys) match { - case Some(rightRowSet) => for (rightRow <- rightRowSet) yield Row.merge(left.row, new GenericRow(rowWrapper(rightRow).values(rightColsToKeep).toArray)) + case Some(rightRowSet) => for (rightRow <- rightRowSet) yield Row.merge(left.row, rightRow) case _ => List(Row.merge(left.row, rightNullRow.copy())) } @@ -54,19 +50,15 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali * Right outer-join using a broadcast variable * * @param other join parameter for second data frame - * @return key-value RDD whose values are results of right-outer join + * 
@return key-value RDD whose values are results of right-outer join */ def rightBroadcastJoin(other: RddJoinParam): RDD[Row] = { val leftBroadcastVariable = JoinBroadcastVariable(self) lazy val leftNullRow: Row = new GenericRow(self.frame.numColumns) - val rowWrapper = new RowWrapper(other.frame.frameSchema) - other.frame.flatMapRows(right => { - // val rightKey = right.value(other.joinColumn) - val leftColsToKeep = self.frame.frameSchema.dropColumns(self.joinColumns.toList).columnNames val rightKeys = right.values(other.joinColumns.toList) leftBroadcastVariable.get(rightKeys) match { - case Some(leftRowSet) => for (leftRow <- leftRowSet) yield Row.merge(new GenericRow(rowWrapper(leftRow).values(leftColsToKeep).toArray), right.row) + case Some(leftRowSet) => for (leftRow <- leftRowSet) yield Row.merge(leftRow, right.row) case _ => List(Row.merge(leftNullRow.copy(), right.row)) } }) @@ -76,7 +68,7 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali * Inner-join using a broadcast variable * * @param other join parameter for second data frame - * @return key-value RDD whose values are results of inner-outer join + * @return key-value RDD whose values are results of inner-outer join */ def innerBroadcastJoin(other: RddJoinParam, broadcastJoinThreshold: Long): RDD[Row] = { val leftSizeInBytes = self.estimatedSizeInBytes.getOrElse(Long.MaxValue) @@ -88,7 +80,6 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali val rightColsToKeep = other.frame.frameSchema.dropColumns(other.joinColumns.toList).columnNames self.frame.flatMapRows(left => { - // val leftKey = left.value(self.joinColumn) val leftKeys = left.values(self.joinColumns.toList) rightBroadcastVariable.get(leftKeys) match { case Some(rightRowSet) => @@ -100,7 +91,6 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali else if (leftSizeInBytes <= broadcastJoinThreshold) { val leftBroadcastVariable = JoinBroadcastVariable(self) other.frame.flatMapRows(rightRow => { - // val rightKey = rightRow.value(other.joinColumn) val leftColsToKeep = self.frame.frameSchema.dropColumns(self.joinColumns.toList).columnNames val rightKeys = rightRow.values(other.joinColumns.toList) leftBroadcastVariable.get(rightKeys) match { diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariable.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariable.scala index 89549c0626..3d6a445eff 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariable.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariable.scala @@ -52,10 +52,7 @@ case class JoinBroadcastVariable(joinParam: RddJoinParam) { // Create the broadcast variable for the join private def createBroadcastMultiMaps(joinParam: RddJoinParam): Broadcast[MultiMap[Any, Row]] = { //Grouping by key to ensure that duplicate keys are not split across different broadcast variables - //val broadcastList = joinParam.frame.groupByRows(row => row.value(joinParam.joinColumn)).collect().toList - val broadcastList = joinParam.frame.groupByRows(row => row.values(joinParam.joinColumns.toList)).collect().toList - val broadcastMultiMap = listToMultiMap(broadcastList) joinParam.frame.sparkContext.broadcast(broadcastMultiMap) } diff --git 
a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala
index 69eb122196..7b76ba1b0d 100644
--- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala
+++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala
@@ -93,6 +93,26 @@ object JoinRddFunctions extends Serializable {
     }
   }
 
+  /**
+   * Builds the join condition to evaluate when a join is invoked with composite keys
+   *
+   * @param leftFrame left data frame
+   * @param rightFrame right data frame
+   * @param leftJoinCols list of left frame column names used in join
+   * @param rightJoinCols list of right frame column names used in join
+   * @return combined column expression for the join condition
+   */
+  def expressionMaker(leftFrame: DataFrame, rightFrame: DataFrame, leftJoinCols: Seq[String], rightJoinCols: Seq[String]): Column = {
+    val columnsTuple = leftJoinCols.zip(rightJoinCols)
+
+    def makeExpression(leftCol: String, rightCol: String): Column = {
+      leftFrame(leftCol).equalTo(rightFrame(rightCol))
+    }
+    var exps = makeExpression(columnsTuple.head._1, columnsTuple.head._2)
+    columnsTuple.tail.map { case (lc, rc) => exps = exps && makeExpression(lc, rc) }
+    exps
+  }
+
   /**
    * Perform full-outer join
    *
@@ -106,17 +126,9 @@
   def fullOuterJoin(left: RddJoinParam, right: RddJoinParam): RDD[Row] = {
     val leftFrame = left.frame.toDataFrame
     val rightFrame = right.frame.toDataFrame
-
-    val columnsTuple = left.joinColumns.zip(right.joinColumns)
-
-    def makeExpression(leftCol: String, rightCol: String): Column = {
-      leftFrame(leftCol).equalTo(rightFrame(rightCol))
-    }
-    var exps = makeExpression(columnsTuple.head._1, columnsTuple.head._2)
-    columnsTuple.tail.map { case (lc, rc) => exps = exps && makeExpression(lc, rc) }
-
+    val expression = expressionMaker(leftFrame, rightFrame, left.joinColumns, right.joinColumns)
     val joinedFrame = leftFrame.join(rightFrame,
-      exps,
+      expression,
       joinType = "fullouter"
     )
     joinedFrame.rdd
@@ -144,17 +156,9 @@
       case _ =>
         val leftFrame = left.frame.toDataFrame
         val rightFrame = right.frame.toDataFrame
-
-        val columnsTuple = left.joinColumns.zip(right.joinColumns)
-
-        def makeExpression(leftCol: String, rightCol: String): Column = {
-          leftFrame(leftCol).equalTo(rightFrame(rightCol))
-        }
-        var exps = makeExpression(columnsTuple.head._1, columnsTuple.head._2)
-        columnsTuple.tail.map { case (lc, rc) => exps = exps && makeExpression(lc, rc) }
-
+        val expression = expressionMaker(leftFrame, rightFrame, left.joinColumns, right.joinColumns)
         val joinedFrame = leftFrame.join(rightFrame,
-          exps,
+          expression,
           joinType = "right"
         )
         joinedFrame.rdd
@@ -181,19 +185,12 @@
       case _ =>
         val leftFrame = left.frame.toDataFrame
         val rightFrame = right.frame.toDataFrame
-
-        val columnsTuple = left.joinColumns.zip(right.joinColumns)
-
-        def makeExpression(leftCol: String, rightCol: String): Column = {
-          leftFrame(leftCol).equalTo(rightFrame(rightCol))
-        }
-        var exps = makeExpression(columnsTuple.head._1, columnsTuple.head._2)
-        columnsTuple.tail.map { case (lc, rc) => exps = exps && makeExpression(lc, rc) }
-
+        val expression = expressionMaker(leftFrame, rightFrame, left.joinColumns, right.joinColumns)
         val joinedFrame = 
leftFrame.join(rightFrame, - exps, + expression, joinType = "left" ) + joinedFrame.rdd } } @@ -216,13 +213,13 @@ object JoinRddFunctions extends Serializable { how match { case "outer" => { val mergedRdd = mergeJoinColumns(joinedRdd, left, right) - dropRightJoinColumn(mergedRdd, left, right) + dropRightJoinColumn(mergedRdd, left, right, how) } case "right" => { dropLeftJoinColumn(joinedRdd, left, right) } case _ => { - dropRightJoinColumn(joinedRdd, left, right) + dropRightJoinColumn(joinedRdd, left, right, how) } } } @@ -243,20 +240,15 @@ object JoinRddFunctions extends Serializable { val leftSchema = left.frame.frameSchema val rightSchema = right.frame.frameSchema - - // val leftJoinIndex = leftSchema.columnIndex(left.joinColumn) - // val rightJoinIndex = rightSchema.columnIndex(right.joinColumn) + leftSchema.columns.size - val leftJoinIndices = leftSchema.columnIndices(left.joinColumns) val rightJoinIndices = rightSchema.columnIndices(right.joinColumns).map(rightindex => rightindex + leftSchema.columns.size) joinedRdd.map(row => { - val rowArray = row.toSeq.toArray leftJoinIndices.zip(rightJoinIndices).map { - case (leftindex, rightindex) => { - if (row.get(leftindex) == null) { - rowArray(leftindex) = row.get(rightindex) + case (leftIndex, rightIndex) => { + if (row.get(leftIndex) == null) { + rowArray(leftIndex) = row.get(rightIndex) } } } @@ -279,12 +271,10 @@ object JoinRddFunctions extends Serializable { right: RddJoinParam): FrameRdd = { val leftSchema = left.frame.frameSchema val rightSchema = right.frame.frameSchema - - // Create new schema - val newLeftSchema = leftSchema.dropColumns(left.joinColumns.toList) - val newSchema = FrameSchema(Schema.join(newLeftSchema.columns, rightSchema.columns)) - - new FrameRdd(newSchema, joinedRdd) + val newSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns)) + val frameRdd = new FrameRdd(newSchema, joinedRdd) + val leftColNames = right.joinColumns.map(col => col + "_L") + frameRdd.dropColumns(leftColNames.toList) } /** @@ -299,15 +289,23 @@ object JoinRddFunctions extends Serializable { */ def dropRightJoinColumn(joinedRdd: RDD[Row], left: RddJoinParam, - right: RddJoinParam): FrameRdd = { + right: RddJoinParam, + how: String): FrameRdd = { val leftSchema = left.frame.frameSchema val rightSchema = right.frame.frameSchema // Create new schema - val newRightSchema = rightSchema.dropColumns(right.joinColumns.toList) - val newSchema = FrameSchema(Schema.join(leftSchema.columns, newRightSchema.columns)) - - new FrameRdd(newSchema, joinedRdd) + if (how == "inner") { + val newRightSchema = rightSchema.dropColumns(right.joinColumns.toList) + val newSchema = FrameSchema(Schema.join(leftSchema.columns, newRightSchema.columns)) + new FrameRdd(newSchema, joinedRdd) + } + else { + val newSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns)) + val frameRdd = new FrameRdd(newSchema, joinedRdd) + val rightColNames = right.joinColumns.map(col => col + "_R") + frameRdd.dropColumns(rightColNames.toList) + } } } diff --git a/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariableITest.scala b/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariableITest.scala index 95b7a0a95d..b65f3788e9 100644 --- a/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariableITest.scala +++ 
b/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinBroadcastVariableITest.scala @@ -46,14 +46,15 @@ class JoinBroadcastVariableITest extends TestingSparkContextFlatSpec with Matche val broadcastVariable = JoinBroadcastVariable(joinParam) broadcastVariable.broadcastMultiMap.value.size should equal(5) - broadcastVariable.get(1).get should contain theSameElementsAs Set(idCountryNames(0), idCountryNames(1)) - broadcastVariable.get(2).get should contain theSameElementsAs Set(idCountryNames(2)) - broadcastVariable.get(3).get should contain theSameElementsAs Set(idCountryNames(3)) - broadcastVariable.get(4).get should contain theSameElementsAs Set(idCountryNames(4)) - broadcastVariable.get(6).get should contain theSameElementsAs Set(idCountryNames(5)) - broadcastVariable.get(8).isDefined should equal(false) + broadcastVariable.get(List(1)).get should contain theSameElementsAs Set(idCountryNames(0), idCountryNames(1)) + broadcastVariable.get(List(2)).get should contain theSameElementsAs Set(idCountryNames(2)) + broadcastVariable.get(List(3)).get should contain theSameElementsAs Set(idCountryNames(3)) + broadcastVariable.get(List(4)).get should contain theSameElementsAs Set(idCountryNames(4)) + broadcastVariable.get(List(6)).get should contain theSameElementsAs Set(idCountryNames(5)) + broadcastVariable.get(List(8)).isDefined should equal(false) } + "JoinBroadcastVariable" should "create a two broadcast variables when RDD size is equals 3GB" in { val countryNames = new FrameRdd(inputSchema, sparkContext.parallelize(idCountryNames)) @@ -61,12 +62,12 @@ class JoinBroadcastVariableITest extends TestingSparkContextFlatSpec with Matche val broadcastVariable = JoinBroadcastVariable(joinParam) - broadcastVariable.get(1).get should contain theSameElementsAs Set(idCountryNames(0), idCountryNames(1)) - broadcastVariable.get(2).get should contain theSameElementsAs Set(idCountryNames(2)) - broadcastVariable.get(3).get should contain theSameElementsAs Set(idCountryNames(3)) - broadcastVariable.get(4).get should contain theSameElementsAs Set(idCountryNames(4)) - broadcastVariable.get(6).get should contain theSameElementsAs Set(idCountryNames(5)) - broadcastVariable.get(8).isDefined should equal(false) + broadcastVariable.get(List(1)).get should contain theSameElementsAs Set(idCountryNames(0), idCountryNames(1)) + broadcastVariable.get(List(2)).get should contain theSameElementsAs Set(idCountryNames(2)) + broadcastVariable.get(List(3)).get should contain theSameElementsAs Set(idCountryNames(3)) + broadcastVariable.get(List(4)).get should contain theSameElementsAs Set(idCountryNames(4)) + broadcastVariable.get(List(6)).get should contain theSameElementsAs Set(idCountryNames(5)) + broadcastVariable.get(List(8)).isDefined should equal(false) } "JoinBroadcastVariable" should "create an empty broadcast variable" in { diff --git a/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/SparkJoinITest.scala b/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/SparkJoinITest.scala index bfd0dfb845..f0c285cf15 100644 --- a/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/SparkJoinITest.scala +++ b/engine-plugins/frame-plugins/src/test/scala/org/trustedanalytics/atk/engine/frame/plugins/join/SparkJoinITest.scala @@ -26,30 +26,32 @@ import org.trustedanalytics.atk.testutils.TestingSparkContextFlatSpec class SparkJoinITest extends 
TestingSparkContextFlatSpec with Matchers { // Test data has duplicate keys, matching and non-matching keys val idCountryCodes: List[Row] = List( - new GenericRow(Array[Any](1, 354)), - new GenericRow(Array[Any](2, 91)), - new GenericRow(Array[Any](2, 100)), - new GenericRow(Array[Any](3, 47)), - new GenericRow(Array[Any](4, 968)), - new GenericRow(Array[Any](5, 50))) + new GenericRow(Array[Any](1, 354, "a")), + new GenericRow(Array[Any](2, 91, "a")), + new GenericRow(Array[Any](2, 100, "b")), + new GenericRow(Array[Any](3, 47, "a")), + new GenericRow(Array[Any](4, 968, "c")), + new GenericRow(Array[Any](5, 50, "c"))) val idCountryNames: List[Row] = List( - new GenericRow(Array[Any](1, "Iceland")), - new GenericRow(Array[Any](1, "Ice-land")), - new GenericRow(Array[Any](2, "India")), - new GenericRow(Array[Any](3, "Norway")), - new GenericRow(Array[Any](4, "Oman")), - new GenericRow(Array[Any](6, "Germany")) + new GenericRow(Array[Any](1, "Iceland", "a")), + new GenericRow(Array[Any](1, "Ice-land", "a")), + new GenericRow(Array[Any](2, "India", "b")), + new GenericRow(Array[Any](3, "Norway", "a")), + new GenericRow(Array[Any](4, "Oman", "c")), + new GenericRow(Array[Any](6, "Germany", "c")) ) val codeSchema = FrameSchema(List( Column("col_0", DataTypes.int32), - Column("col_1", DataTypes.int32) + Column("col_1", DataTypes.int32), + Column("col_2", DataTypes.str) )) val countrySchema = FrameSchema(List( Column("col_0", DataTypes.int32), - Column("col_1", DataTypes.str) + Column("col_1", DataTypes.str), + Column("col_2", DataTypes.str) )) "joinRDDs" should "join two RDD with inner join" in { @@ -64,16 +66,18 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { resultFrame.frameSchema.columns should equal(List( Column("col_0", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_1_R", DataTypes.str, 2) + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, "Iceland")), - new GenericRow(Array[Any](1, 354, "Ice-land")), - new GenericRow(Array[Any](2, 91, "India")), - new GenericRow(Array[Any](2, 100, "India")), - new GenericRow(Array[Any](3, 47, "Norway")), - new GenericRow(Array[Any](4, 968, "Oman")) + new GenericRow(Array[Any](1, 354, "a", "Iceland", "a")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land", "a")), + new GenericRow(Array[Any](2, 91, "a", "India", "b")), + new GenericRow(Array[Any](2, 100, "b", "India", "b")), + new GenericRow(Array[Any](3, 47, "a", "Norway", "a")), + new GenericRow(Array[Any](4, 968, "c", "Oman", "c")) ) results should contain theSameElementsAs expectedResults @@ -92,16 +96,73 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { resultFrame.frameSchema.columns should equal(List( Column("col_0", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_1_R", DataTypes.str, 2) + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354, "Iceland")), - new GenericRow(Array[Any](1, 354, "Ice-land")), - new GenericRow(Array[Any](2, 91, "India")), - new GenericRow(Array[Any](2, 100, "India")), - new GenericRow(Array[Any](3, 47, "Norway")), - new GenericRow(Array[Any](4, 968, "Oman")) + new GenericRow(Array[Any](1, 354, "a", "Iceland", "a")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land", "a")), + new GenericRow(Array[Any](2, 91, 
"a", "India", "b")), + new GenericRow(Array[Any](2, 100, "b", "India", "b")), + new GenericRow(Array[Any](3, 47, "a", "Norway", "a")), + new GenericRow(Array[Any](4, 968, "c", "Oman", "c")) + ) + + results should contain theSameElementsAs expectedResults + } + + "compositeJoinRDDs" should "join two RDD with inner join" in { + + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0", "col_2")), + RddJoinParam(countryNames, Seq("col_0", "col_2")), "inner") + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_0", DataTypes.int32, 0), + Column("col_1_L", DataTypes.int32, 1), + Column("col_2", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3) + )) + + val expectedResults = List( + new GenericRow(Array[Any](1, 354, "a", "Iceland")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land")), + new GenericRow(Array[Any](2, 100, "b", "India")), + new GenericRow(Array[Any](3, 47, "a", "Norway")), + new GenericRow(Array[Any](4, 968, "c", "Oman")) + ) + + results should contain theSameElementsAs expectedResults + } + + "compositeJoinRDDs" should "join two RDD with inner join using broadcast variable" in { + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val leftJoinParam = RddJoinParam(countryCode, Seq("col_0", "col_2"), Some(150)) + val rightJoinParam = RddJoinParam(countryNames, Seq("col_0", "col_2"), Some(10000)) + + val resultFrame = JoinRddFunctions.join(leftJoinParam, rightJoinParam, "inner") + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_0", DataTypes.int32, 0), + Column("col_1_L", DataTypes.int32, 1), + Column("col_2", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3) + )) + + val expectedResults = List( + new GenericRow(Array[Any](1, 354, "a", "Iceland")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land")), + new GenericRow(Array[Any](2, 100, "b", "India")), + new GenericRow(Array[Any](3, 47, "a", "Norway")), + new GenericRow(Array[Any](4, 968, "c", "Oman")) ) results should contain theSameElementsAs expectedResults @@ -116,19 +177,21 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( - Column("col_0", DataTypes.int32, 0), + Column("col_0_L", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_1_R", DataTypes.str, 2) + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354,"Iceland")), - new GenericRow(Array[Any](1, 354,"Ice-land")), - new GenericRow(Array[Any](2, 91,"India")), - new GenericRow(Array[Any](2, 100,"India")), - new GenericRow(Array[Any](3, 47,"Norway")), - new GenericRow(Array[Any](4, 968,"Oman")), - new GenericRow(Array[Any](5, 50, null)) + new GenericRow(Array[Any](1, 354, "a", "Iceland", "a")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land", "a")), + new GenericRow(Array[Any](2, 91, "a", "India", "b")), + new GenericRow(Array[Any](2, 100, "b", "India", "b")), + new GenericRow(Array[Any](3, 47, "a", "Norway", "a")), + new GenericRow(Array[Any](4, 968, "c", "Oman", 
"c")), + new GenericRow(Array[Any](5, 50, "c", null, null)) ) results should contain theSameElementsAs expectedResults @@ -146,23 +209,85 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( - Column("col_0", DataTypes.int32, 0), + Column("col_0_L", DataTypes.int32, 0), + Column("col_1_L", DataTypes.int32, 1), + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) + )) + + val expectedResults = List( + new GenericRow(Array[Any](1, 354, "a", "Iceland", "a")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land", "a")), + new GenericRow(Array[Any](2, 91, "a", "India", "b")), + new GenericRow(Array[Any](2, 100, "b", "India", "b")), + new GenericRow(Array[Any](3, 47, "a", "Norway", "a")), + new GenericRow(Array[Any](4, 968, "c", "Oman", "c")), + new GenericRow(Array[Any](5, 50, "c", null, null)) + ) + + results should contain theSameElementsAs expectedResults + } + + "compositeJoinRDDs" should "join two RDD with left join" in { + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0", "col_2")), + RddJoinParam(countryNames, Seq("col_0", "col_2")), "left") + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_0_L", DataTypes.int32, 0), + Column("col_1_L", DataTypes.int32, 1), + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3) + )) + + val expectedResults = List( + new GenericRow(Array[Any](1, 354, "a", "Iceland")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land")), + new GenericRow(Array[Any](2, 91, "a", null)), + new GenericRow(Array[Any](2, 100, "b", "India")), + new GenericRow(Array[Any](3, 47, "a", "Norway")), + new GenericRow(Array[Any](4, 968, "c", "Oman")), + new GenericRow(Array[Any](5, 50, "c", null)) + ) + + results should contain theSameElementsAs expectedResults + } + + "compositeJoinRDDs" should "join two RDD with left join using broadcast variable" in { + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val leftJoinParam = RddJoinParam(countryCode, Seq("col_0", "col_2"), Some(1500L)) + val rightJoinParam = RddJoinParam(countryNames, Seq("col_0", "col_2"), Some(100L + Int.MaxValue)) + + // Test join wrapper function + val resultFrame = JoinRddFunctions.join(leftJoinParam, rightJoinParam, "left") + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_0_L", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_1_R", DataTypes.str, 2) + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354,"Iceland")), - new GenericRow(Array[Any](1, 354, "Ice-land")), - new GenericRow(Array[Any](2, 91, "India")), - new GenericRow(Array[Any](2, 100, "India")), - new GenericRow(Array[Any](3, 47, "Norway")), - new GenericRow(Array[Any](4, 968, "Oman")), - new GenericRow(Array[Any](5, 50, null)) + new GenericRow(Array[Any](1, 354, "a", "Iceland")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land")), + new GenericRow(Array[Any](2, 91, "a", null)), + new 
GenericRow(Array[Any](2, 100, "b", "India")), + new GenericRow(Array[Any](3, 47, "a", "Norway")), + new GenericRow(Array[Any](4, 968, "c", "Oman")), + new GenericRow(Array[Any](5, 50, "c", null)) ) results should contain theSameElementsAs expectedResults } + "joinRDDs" should "join two RDD with right join" in { val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) @@ -174,18 +299,20 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { resultFrame.frameSchema.columns should equal(List( Column("col_1_L", DataTypes.int32, 0), - Column("col_0", DataTypes.int32, 1), - Column("col_1_R", DataTypes.str, 2) + Column("col_2_L", DataTypes.str, 1), + Column("col_0_R", DataTypes.int32, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) )) val expectedResults = List( - new GenericRow(Array[Any](354, 1, "Iceland")), - new GenericRow(Array[Any](354, 1, "Ice-land")), - new GenericRow(Array[Any](91, 2, "India")), - new GenericRow(Array[Any](100, 2, "India")), - new GenericRow(Array[Any](47, 3, "Norway")), - new GenericRow(Array[Any](968, 4, "Oman")), - new GenericRow(Array[Any](null, 6, "Germany")) + new GenericRow(Array[Any](354, "a", 1, "Iceland", "a")), + new GenericRow(Array[Any](354, "a", 1, "Ice-land", "a")), + new GenericRow(Array[Any](91, "a", 2, "India", "b")), + new GenericRow(Array[Any](100, "b", 2, "India", "b")), + new GenericRow(Array[Any](47, "a", 3, "Norway", "a")), + new GenericRow(Array[Any](968, "c", 4, "Oman", "c")), + new GenericRow(Array[Any](null, null, 6, "Germany", "c")) ) results should contain theSameElementsAs expectedResults @@ -204,18 +331,78 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { resultFrame.frameSchema.columns should equal(List( Column("col_1_L", DataTypes.int32, 0), - Column("col_0", DataTypes.int32, 1), - Column("col_1_R", DataTypes.str, 2) + Column("col_2_L", DataTypes.str, 1), + Column("col_0_R", DataTypes.int32, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) + )) + + val expectedResults = List( + new GenericRow(Array[Any](354, "a", 1, "Iceland", "a")), + new GenericRow(Array[Any](354, "a", 1, "Ice-land", "a")), + new GenericRow(Array[Any](91, "a", 2, "India", "b")), + new GenericRow(Array[Any](100, "b", 2, "India", "b")), + new GenericRow(Array[Any](47, "a", 3, "Norway", "a")), + new GenericRow(Array[Any](968, "c", 4, "Oman", "c")), + new GenericRow(Array[Any](null, null, 6, "Germany", "c")) + ) + + results should contain theSameElementsAs expectedResults + } + + "compositeJoinRDDs" should "join two RDD with right join" in { + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val resultFrame = JoinRddFunctions.join( + RddJoinParam(countryCode, Seq("col_0", "col_2")), + RddJoinParam(countryNames, Seq("col_0", "col_2")), "right") + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_1_L", DataTypes.int32, 0), + Column("col_0_R", DataTypes.int32, 1), + Column("col_1_R", DataTypes.str, 2), + Column("col_2_R", DataTypes.str, 3) + )) + + val expectedResults = List( + new GenericRow(Array[Any](354, 1, "Iceland", "a")), + new GenericRow(Array[Any](354, 1, "Ice-land", "a")), + new GenericRow(Array[Any](100, 2, "India", "b")), + new GenericRow(Array[Any](47, 3, 
"Norway", "a")), + new GenericRow(Array[Any](968, 4, "Oman", "c")), + new GenericRow(Array[Any](null, 6, "Germany", "c")) + ) + + results should contain theSameElementsAs expectedResults + } + + "compositeJoinRDDs" should "join two RDD with right join using broadcast variable" in { + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val broadcastJoinThreshold = 1000 + val leftJoinParam = RddJoinParam(countryCode, Seq("col_0", "col_2"), Some(800)) + val rightJoinParam = RddJoinParam(countryNames, Seq("col_0", "col_2"), Some(4000)) + + val resultFrame = JoinRddFunctions.join(leftJoinParam, rightJoinParam, "right", broadcastJoinThreshold) + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_1_L", DataTypes.int32, 0), + Column("col_0_R", DataTypes.int32, 1), + Column("col_1_R", DataTypes.str, 2), + Column("col_2_R", DataTypes.str, 3) )) val expectedResults = List( - new GenericRow(Array[Any](354, 1, "Iceland")), - new GenericRow(Array[Any](354, 1, "Ice-land")), - new GenericRow(Array[Any](91, 2, "India")), - new GenericRow(Array[Any](100, 2, "India")), - new GenericRow(Array[Any](47, 3, "Norway")), - new GenericRow(Array[Any](968, 4, "Oman")), - new GenericRow(Array[Any](null, 6, "Germany")) + new GenericRow(Array[Any](354, 1, "Iceland", "a")), + new GenericRow(Array[Any](354, 1, "Ice-land", "a")), + new GenericRow(Array[Any](100, 2, "India", "b")), + new GenericRow(Array[Any](47, 3, "Norway", "a")), + new GenericRow(Array[Any](968, 4, "Oman", "c")), + new GenericRow(Array[Any](null, 6, "Germany", "c")) ) results should contain theSameElementsAs expectedResults @@ -230,20 +417,51 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( - Column("col_0", DataTypes.int32, 0), + Column("col_0_L", DataTypes.int32, 0), + Column("col_1_L", DataTypes.int32, 1), + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) + )) + + val expectedResults = List( + new GenericRow(Array[Any](1, 354, "a", "Iceland", "a")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land", "a")), + new GenericRow(Array[Any](2, 91, "a", "India", "b")), + new GenericRow(Array[Any](2, 100, "b", "India", "b")), + new GenericRow(Array[Any](3, 47, "a", "Norway", "a")), + new GenericRow(Array[Any](4, 968, "c", "Oman", "c")), + new GenericRow(Array[Any](5, 50, "c", null, null)), + new GenericRow(Array[Any](6, null, null, "Germany", "c")) + ) + + results should contain theSameElementsAs expectedResults + } + + "compositeJoinRDDs" should "join two RDD with outer join" in { + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(idCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0", "col_2")), + RddJoinParam(countryNames, Seq("col_0", "col_2")), "outer") + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_0_L", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_1_R", DataTypes.str, 2) + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3) )) val expectedResults = List( - new GenericRow(Array[Any](1, 354,"Iceland")), - new GenericRow(Array[Any](1, 
354,"Ice-land")), - new GenericRow(Array[Any](2, 91,"India")), - new GenericRow(Array[Any](2, 100,"India")), - new GenericRow(Array[Any](3, 47,"Norway")), - new GenericRow(Array[Any](4, 968, "Oman")), - new GenericRow(Array[Any](5, 50, null)), - new GenericRow(Array[Any](6, null,"Germany")) + new GenericRow(Array[Any](1, 354, "a", "Iceland")), + new GenericRow(Array[Any](1, 354, "a", "Ice-land")), + new GenericRow(Array[Any](2, 91, "a", null)), + new GenericRow(Array[Any](2, 100, "b", "India")), + new GenericRow(Array[Any](3, 47, "a", "Norway")), + new GenericRow(Array[Any](4, 968, "c", "Oman")), + new GenericRow(Array[Any](5, 50, "c", null)), + new GenericRow(Array[Any](6, null, "c", "Germany")) ) results should contain theSameElementsAs expectedResults @@ -258,18 +476,47 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( - Column("col_0", DataTypes.int32, 0), + Column("col_0_L", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_1_R", DataTypes.str, 2) + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + Column("col_2_R", DataTypes.str, 4) )) val expectedResults = List( - new GenericRow(Array[Any](1, null,"Iceland")), - new GenericRow(Array[Any](1, null,"Ice-land")), - new GenericRow(Array[Any](2, null,"India")), - new GenericRow(Array[Any](3, null, "Norway")), - new GenericRow(Array[Any](4, null,"Oman")), - new GenericRow(Array[Any](6, null,"Germany")) + new GenericRow(Array[Any](1, null, null, "Iceland", "a")), + new GenericRow(Array[Any](1, null, null, "Ice-land", "a")), + new GenericRow(Array[Any](2, null, null, "India", "b")), + new GenericRow(Array[Any](3, null, null, "Norway", "a")), + new GenericRow(Array[Any](4, null, null, "Oman", "c")), + new GenericRow(Array[Any](6, null, null, "Germany", "c")) + ) + + results should contain theSameElementsAs expectedResults + } + + "compositeOuter join with empty left RDD" should "preserve the result from the right RDD" in { + val emptyIdCountryCodes = List.empty[Row] + val countryCode = new FrameRdd(codeSchema, sparkContext.parallelize(emptyIdCountryCodes)) + val countryNames = new FrameRdd(countrySchema, sparkContext.parallelize(idCountryNames)) + + val resultFrame = JoinRddFunctions.join(RddJoinParam(countryCode, Seq("col_0", "col_2")), RddJoinParam(countryNames, Seq("col_0", "col_2")), "outer") + val results = resultFrame.collect() + + resultFrame.frameSchema.columns should equal(List( + Column("col_0_L", DataTypes.int32, 0), + Column("col_1_L", DataTypes.int32, 1), + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3) + )) + + val expectedResults = List( + new GenericRow(Array[Any](1, null, "a", "Iceland")), + new GenericRow(Array[Any](1, null, "a", "Ice-land")), + new GenericRow(Array[Any](2, null, "b", "India")), + new GenericRow(Array[Any](3, null, "a", "Norway")), + new GenericRow(Array[Any](4, null, "c", "Oman")), + new GenericRow(Array[Any](6, null, "c", "Germany")) ) results should contain theSameElementsAs expectedResults @@ -285,18 +532,20 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers { val results = resultFrame.collect() resultFrame.frameSchema.columns should equal(List( - Column("col_0", DataTypes.int32, 0), + Column("col_0_L", DataTypes.int32, 0), Column("col_1_L", DataTypes.int32, 1), - Column("col_1_R", DataTypes.str, 2) + Column("col_2_L", DataTypes.str, 2), + Column("col_1_R", DataTypes.str, 3), + 
Column("col_2_R", DataTypes.str, 4)
     ))
 
     val expectedResults = List(
-      new GenericRow(Array[Any](1, 354, null)),
-      new GenericRow(Array[Any](2, 91, null)),
-      new GenericRow(Array[Any](2, 100, null)),
-      new GenericRow(Array[Any](3, 47, null)),
-      new GenericRow(Array[Any](4, 968, null)),
-      new GenericRow(Array[Any](5, 50, null))
+      new GenericRow(Array[Any](1, 354, "a", null, null)),
+      new GenericRow(Array[Any](2, 91, "a", null, null)),
+      new GenericRow(Array[Any](2, 100, "b", null, null)),
+      new GenericRow(Array[Any](3, 47, "a", null, null)),
+      new GenericRow(Array[Any](4, 968, "c", null, null)),
+      new GenericRow(Array[Any](5, 50, "c", null, null))
     )
 
     results should contain theSameElementsAs expectedResults
@@ -321,8 +570,7 @@ class SparkJoinITest extends TestingSparkContextFlatSpec with Matchers {
       RddJoinParam(rddFiveHundredThousandsToOneFiftyThousands, Seq("col_0")), "outer")
 
     resultFrame.frameSchema.columns should equal(List(
-      Column("col_0", DataTypes.int32, 0),
-      Column("col_0_R", DataTypes.int32, 1)
+      Column("col_0_L", DataTypes.int32, 0)
     ))
     resultFrame.count shouldBe 150000
   }
diff --git a/python-client/trustedanalytics/core/frame.py b/python-client/trustedanalytics/core/frame.py
index 58f1194546..d6f2c13683 100644
--- a/python-client/trustedanalytics/core/frame.py
+++ b/python-client/trustedanalytics/core/frame.py
@@ -1089,21 +1089,21 @@ def __join(self, right, left_on, right_on=None, how='inner', name=None):
         Create a new frame from a SQL JOIN operation with another frame.
         The frame on the 'left' is the currently active frame.
         The frame on the 'right' is another frame.
-        This method takes a column in the left frame and matches its values
-        with a column in the right frame.
+        This method takes column(s) in the left frame and matches their values
+        with column(s) in the right frame.
         Using the default 'how' option ['inner'] will only allow data in the
         resultant frame if both the left and right frames have the same value
-        in the matching column.
+        in the matching column(s).
         Using the 'left' 'how' option will allow any data in the resultant
         frame if it exists in the left frame, but will allow any data from the
-        right frame if it has a value in its column which matches the value in
-        the left frame column.
+        right frame if it has a value in its column(s) which matches the value in
+        the left frame column(s).
         Using the 'right' option works similarly, except it keeps all the data
         from the right frame and only the data from the left frame when it
         matches.
         The 'outer' option provides a frame with data from both frames where
         the left and right frames did not have the same value in the matching
-        column.
+        column(s).
Notes ----- @@ -1134,6 +1134,20 @@ def __join(self, right, left_on, right_on=None, how='inner', name=None): >>> colors = ta.Frame(ta.UploadRows([[1, 'red'], [2, 'yellow'], [3, 'green'], [4, 'blue']], [('numbers', ta.int32), ('color', str)])) -etc- + >>> country_code_rows = [[1, 354, "a"],[2, 91, "a"],[2, 100, "b"],[3, 47, "a"],[4, 968, "c"],[5, 50, "c"]] + >>> country_code_schema = [("col_0", int),("col_1", int),("col_2",str)] + -etc- + + >>> country_name_rows = [[1, "Iceland", "a"],[1, "Ice-land", "a"],[2, "India", "b"],[3, "Norway", "a"],[4, "Oman", "c"],[6, "Germany", "c"]] + >>> country_names_schema = [("col_0", int),("col_1", str),("col_2",str)] + -etc- + + >>> country_codes_frame = ta.Frame(ta.UploadRows(country_code_rows, country_code_schema)) + -etc- + + >>> country_names_frame= ta.Frame(ta.UploadRows(country_name_rows, country_names_schema)) + -etc- + Consider two frames: codes and colors @@ -1162,15 +1176,10 @@ def __join(self, right, left_on, right_on=None, how='inner', name=None): Join them on the 'numbers' column ('inner' join by default) - >>> j = codes.join(colors, 'numbers') + >>> j = codes.join(colors, ['numbers']) - >>> j.inspect() - - - >>> j.inspect(columns=['numbers', 'color']) - [#] numbers color ==================== [0] 1 red @@ -1184,48 +1193,76 @@ def __join(self, right, left_on, right_on=None, how='inner', name=None): Try a 'left' join, which includes all the rows of the codes frame. - >>> j_left = codes.join(colors, 'numbers', how='left') + >>> j_left = codes.join(colors, ['numbers'], how='left') - >>> j_left.inspect() - - - >>> j_left.inspect(columns=['numbers', 'color']) - - [#] numbers color - ==================== - [0] 1 red - [1] 3 green - [2] 1 red - [3] 0 None - [4] 2 yellow - [5] 1 red - [6] 5 None - [7] 3 green + [#] numbers_L color + ====================== + [0] 1 red + [1] 3 green + [2] 1 red + [3] 0 None + [4] 2 yellow + [5] 1 red + [6] 5 None + [7] 3 green + And an outer join: - >>> j_outer = codes.join(colors, 'numbers', how='outer') + >>> j_outer = codes.join(colors, ['numbers'], how='outer') - >>> j_outer.inspect() - - - >>> j_outer.inspect(columns=['numbers', 'color']) - - [#] numbers color - ==================== - [0] 0 None - [1] 1 red - [2] 1 red - [3] 1 red - [4] 2 yellow - [5] 3 green - [6] 3 green - [7] 4 blue - [8] 5 None + [#] numbers_L color + ====================== + [0] 0 None + [1] 1 red + [2] 1 red + [3] 1 red + [4] 2 yellow + [5] 3 green + [6] 3 green + [7] 4 blue + [8] 5 None + + Consider two frames: country_codes_frame and country_names_frame + + >>> country_codes_frame.inspect() + [#] col_0 col_1 col_2 + ======================== + [0] 1 354 a + [1] 2 91 a + [2] 2 100 b + [3] 3 47 a + [4] 4 968 c + [5] 5 50 c + + + >>> country_names_frame.inspect() + [#] col_0 col_1 col_2 + =========================== + [0] 1 Iceland a + [1] 1 Ice-land a + [2] 2 India b + [3] 3 Norway a + [4] 4 Oman c + [5] 6 Germany c + + Join them on the 'col_0' and 'col_2' columns ('inner' join by default) + + >>> composite_join = country_codes_frame.join(country_names_frame, ['col_0', 'col_2']) + + + >>> composite_join.inspect() + [#] col_0 col_1_L col_2 col_1_R + ==================================== + [0] 1 354 a Ice-land + [1] 1 354 a Iceland + [2] 2 100 b India + [3] 3 47 a Norway + [4] 4 968 c Oman More examples can be found in the :ref:`user manual `. 
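A note on the technique in PATCH 07 above: with composite keys, the join condition handed to Spark is simply one equality expression per key pair, AND-ed into a single Column, which is what expressionMaker builds before DataFrame.join is called. The sketch below is illustrative only, not ATK code; it assumes a plain Spark 1.x SQLContext, and the frame and column names are made up to mirror the test data:

    import org.apache.spark.{ SparkConf, SparkContext }
    import org.apache.spark.sql.{ Column, DataFrame, SQLContext }

    object CompositeJoinSketch {
      // One equality per key pair, folded into a single condition with &&,
      // mirroring expressionMaker in JoinRddFunctions.scala.
      def compositeKeyCondition(left: DataFrame, right: DataFrame,
                                leftCols: Seq[String], rightCols: Seq[String]): Column =
        leftCols.zip(rightCols)
          .map { case (lc, rc) => left(lc).equalTo(right(rc)) }
          .reduce(_ && _)

      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("composite-join-sketch").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        // Toy frames shaped like the idCountryCodes/idCountryNames test data
        val codes = sc.parallelize(Seq((1, 354, "a"), (2, 100, "b"))).toDF("col_0", "col_1", "col_2")
        val names = sc.parallelize(Seq((1, "Iceland", "a"), (2, "India", "b"))).toDF("col_0", "col_1", "col_2")

        val condition = compositeKeyCondition(codes, names, Seq("col_0", "col_2"), Seq("col_0", "col_2"))
        codes.join(names, condition, "inner").show()
        sc.stop()
      }
    }

PATCH 12 below tightens this same pattern inside expressionMaker itself, replacing the var-and-mutate loop with an equivalent map/reduce.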
From 578dbe7a85b12b2d6c8149b7304575be7df78253 Mon Sep 17 00:00:00 2001
From: Karthik Vadla
Date: Wed, 6 Apr 2016 18:38:06 -0700
Subject: [PATCH 08/15] Removed timer comments that were placed for group_by

---
 .../engine/frame/plugins/UnflattenColumnsPlugin.scala | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/UnflattenColumnsPlugin.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/UnflattenColumnsPlugin.scala
index 23defad426..d3ce5c06d1 100644
--- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/UnflattenColumnsPlugin.scala
+++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/UnflattenColumnsPlugin.scala
@@ -65,19 +65,10 @@ class UnflattenColumnPlugin extends SparkCommandPlugin[UnflattenColumnArgs, Unit
     val schema = frame.schema
     val compositeKeyNames = arguments.columns
     val compositeKeyIndices = compositeKeyNames.map(schema.columnIndex)
-
     // run the operation
     val targetSchema = UnflattenColumnFunctions.createTargetSchema(schema, compositeKeyNames)
-
-    //added timer for unflatten
-    println(s"Row Count before Unflatten groupby ${frame.rdd.count()}")
-    val start = System.nanoTime()
     val initialRdd = frame.rdd.groupByRows(row => row.values(compositeKeyNames))
-    val end = System.nanoTime()
-    initialRdd.count()
-    println(s"Unflatten Groupby Time ${end - start}")
     val resultRdd = UnflattenColumnFunctions.unflattenRddByCompositeKey(compositeKeyIndices, initialRdd, targetSchema, arguments.delimiter.getOrElse(defaultDelimiter))
-
     frame.save(new FrameRdd(targetSchema, resultRdd))
   }

From 61185c5d8c9575fb0d0e55d3e13fc441765f6d1f Mon Sep 17 00:00:00 2001
From: Karthik Vadla
Date: Thu, 7 Apr 2016 10:36:56 -0700
Subject: [PATCH 09/15] removed unnecessary lines

---
 .../atk/engine/frame/PythonRddStorage.scala | 42 ++++---------------
 1 file changed, 7 insertions(+), 35 deletions(-)

diff --git a/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala b/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala
index 8fe1245fea..baeb6117bd 100644
--- a/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala
+++ b/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala
@@ -74,13 +74,8 @@ object PythonRddStorage {
       udfSchema
     }
     val converter = DataTypes.parseMany(newSchema.columns.map(_.dataType).toArray)(_)
-    val accumulatorSer = sc.accumulator(0L, "mytimerSerGeneric")
-    val accumulatorDeSer = sc.accumulator(0L, "mytimerDeSerGeneric")
-    val pyRdd = rddToPyRdd(udf, data, sc, accumulatorSer)
-    val frameRdd = getRddFromPythonRdd(pyRdd, converter, accumulatorDeSer)
-    println(s"MytimerSer in mapWith took ${accumulatorSer.value}")
-    frameRdd.count()
-    println(s"MytimerDeSer in mapWith took ${accumulatorDeSer.value}")
+    val pyRdd = rddToPyRdd(udf, data, sc)
+    val frameRdd = getRddFromPythonRdd(pyRdd, converter)
     FrameRdd.toFrameRdd(newSchema, frameRdd)
   }

@@ -99,17 +94,9 @@ object PythonRddStorage {
     //track key indices to fetch data during BSON decode.
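    // Illustrative sketch, not part of this patch: rows cross the JVM/Python boundary
    // as BSON documents carrying a single "array" field, as rddToPyRdd and
    // aggregateRddToPyRdd below show. A minimal standalone version of that encoding,
    // assuming mongo-java-driver's BSON classes on the classpath:
    //
    //   import org.bson.{ BSON, BasicBSONObject }
    //
    //   object BsonRowSketch {
    //     def encodeRow(values: Array[Any]): Array[Byte] = {
    //       val obj = new BasicBSONObject()
    //       obj.put("array", values) // the Python side reads this "array" field back
    //       BSON.encode(obj)
    //     }
    //
    //     def main(args: Array[String]): Unit = {
    //       val bytes = encodeRow(Array[Any](1, 354, "Iceland"))
    //       println(s"encoded ${bytes.length} bytes: ${BSON.decode(bytes)}")
    //     }
    //   }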
//val keyIndices = for (key <- aggregateByColumnKeys) yield data.frameSchema.columnIndex(key) val converter = DataTypes.parseMany(keyedSchema.columns.map(_.dataType).toArray)(_) - val groupRDD = data.groupByRows(row => row.values(aggregateByColumnKeys)) - - val accumulatorSer = sc.accumulator(0L, "mytimerSerAggregated") - val accumulatorDeSer = sc.accumulator(0L, "mytimerDeSerAggregated") - val pyRdd = aggregateRddToPyRdd(udf, groupRDD, sc, accumulatorSer) - val frameRdd = getRddFromPythonRdd(pyRdd, converter, accumulatorDeSer) - //serialization timer - println(s"MytimerSer in AggregateUDF took ${accumulatorSer.value}") - frameRdd.count() - println(s"MytimerDeSer in AggregateUDF took ${accumulatorDeSer.value}") + val pyRdd = aggregateRddToPyRdd(udf, groupRDD, sc) + val frameRdd = getRddFromPythonRdd(pyRdd, converter) FrameRdd.toFrameRdd(keyedSchema, frameRdd) } @@ -139,12 +126,11 @@ object PythonRddStorage { bsonList } - def rddToPyRdd(udf: Udf, rdd: RDD[Row], sc: SparkContext, acc: Accumulator[Long] = null): EnginePythonRdd[Array[Byte]] = { + def rddToPyRdd(udf: Udf, rdd: RDD[Row], sc: SparkContext): EnginePythonRdd[Array[Byte]] = { val predicateInBytes = decodePythonBase64EncodedStrToBytes(udf.function) // Create an RDD of byte arrays representing bson objects val baseRdd: RDD[Array[Byte]] = rdd.map( x => { - val start = System.nanoTime() val obj = new BasicBSONObject() obj.put("array", x.toSeq.toArray.map { case y: ArrayBuffer[_] => iterableToBsonList(y) @@ -153,13 +139,9 @@ object PythonRddStorage { case value => value }) val res = BSON.encode(obj) - println(s"Bson Encoded obj: ${res}") - if (acc != null) - acc += (System.nanoTime() - start) res } ) - println(s"RddToPyRddGeneric Bytes ${baseRdd.first()} ${baseRdd.first().length}") val pyRdd = getPyRdd(udf, sc, baseRdd, predicateInBytes) pyRdd } @@ -220,11 +202,10 @@ object PythonRddStorage { * @param rdd rdd(List[keys], List[Rows]) * @return PythonRdd */ - def aggregateRddToPyRdd(udf: Udf, rdd: RDD[(List[Any], Iterable[Row])], sc: SparkContext, acc: Accumulator[Long] = null): EnginePythonRdd[Array[Byte]] = { + def aggregateRddToPyRdd(udf: Udf, rdd: RDD[(List[Any], Iterable[Row])], sc: SparkContext): EnginePythonRdd[Array[Byte]] = { val predicateInBytes = decodePythonBase64EncodedStrToBytes(udf.function) val baseRdd: RDD[Array[Byte]] = rdd.map { case (key, rows) => { - val x = System.nanoTime() val obj = new BasicBSONObject() val bsonRows = rows.map( row => { @@ -238,14 +219,9 @@ object PythonRddStorage { //obj.put("keyindices", keyIndices.toArray) obj.put("array", bsonRows) val res = BSON.encode(obj) - println(s"Bson Encoded obj: ${res}") - val y = System.nanoTime() - if (acc != null) - acc += (y - x) res } } - println(s"agg12-RddToPyRddAggregation Bytes ${baseRdd.first()} ${baseRdd.first().length}") val pyRdd = getPyRdd(udf, sc, baseRdd, predicateInBytes) pyRdd } @@ -271,9 +247,8 @@ object PythonRddStorage { result.headOption } - def getRddFromPythonRdd(pyRdd: EnginePythonRdd[Array[Byte]], converter: (Array[Any] => Array[Any]) = null, acc: Accumulator[Long] = null): RDD[Array[Any]] = { + def getRddFromPythonRdd(pyRdd: EnginePythonRdd[Array[Byte]], converter: (Array[Any] => Array[Any]) = null): RDD[Array[Any]] = { val resultRdd = pyRdd.flatMap(s => { - val start = System.nanoTime() //should be BasicBSONList containing only BasicBSONList objects val bson = BSON.decode(s) val asList = bson.get("array").asInstanceOf[BasicBSONList] @@ -284,9 +259,6 @@ object PythonRddStorage { case value => value }.toArray.asInstanceOf[Array[Any]] }) - val 
end = System.nanoTime() - if (acc != null) - acc += (end - start) res }).map(converter) From 06e45652af39c3f4bbab38852a909ed2cad5cdf8 Mon Sep 17 00:00:00 2001 From: Karthik Vadla Date: Thu, 7 Apr 2016 10:38:24 -0700 Subject: [PATCH 10/15] removed lines in .gitignore --- conf/.gitignore | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/conf/.gitignore b/conf/.gitignore index 9fac4bc444..9475c0774e 100644 --- a/conf/.gitignore +++ b/conf/.gitignore @@ -1,5 +1,3 @@ application.conf dev.conf -logback.xml -application.conf.xavier3 -generated.conf.xavier3 \ No newline at end of file +logback.xml \ No newline at end of file From 0f69929727f342c1a69f6f555549aa882f960b78 Mon Sep 17 00:00:00 2001 From: Karthik Vadla Date: Thu, 7 Apr 2016 11:39:29 -0700 Subject: [PATCH 11/15] Added changes as per britions feedback --- .../plugins/join/BroadcastJoinRddFunctions.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala index 0ba4ecd592..16b87baf96 100644 --- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala +++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/BroadcastJoinRddFunctions.scala @@ -36,8 +36,9 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali def leftBroadcastJoin(other: RddJoinParam): RDD[Row] = { val rightBroadcastVariable = JoinBroadcastVariable(other) lazy val rightNullRow: Row = new GenericRow(other.frame.numColumns) + val leftJoinColumns = self.joinColumns.toList self.frame.flatMapRows(left => { - val leftKeys = left.values(self.joinColumns.toList) + val leftKeys = left.values(leftJoinColumns) rightBroadcastVariable.get(leftKeys) match { case Some(rightRowSet) => for (rightRow <- rightRowSet) yield Row.merge(left.row, rightRow) case _ => List(Row.merge(left.row, rightNullRow.copy())) @@ -55,8 +56,9 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali def rightBroadcastJoin(other: RddJoinParam): RDD[Row] = { val leftBroadcastVariable = JoinBroadcastVariable(self) lazy val leftNullRow: Row = new GenericRow(self.frame.numColumns) + val rightJoinColumns = other.joinColumns.toList other.frame.flatMapRows(right => { - val rightKeys = right.values(other.joinColumns.toList) + val rightKeys = right.values(rightJoinColumns) leftBroadcastVariable.get(rightKeys) match { case Some(leftRowSet) => for (leftRow <- leftRowSet) yield Row.merge(leftRow, right.row) case _ => List(Row.merge(leftNullRow.copy(), right.row)) @@ -79,8 +81,9 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends Logging with Seriali val rightBroadcastVariable = JoinBroadcastVariable(other) val rightColsToKeep = other.frame.frameSchema.dropColumns(other.joinColumns.toList).columnNames + val leftJoinColumns = self.joinColumns.toList self.frame.flatMapRows(left => { - val leftKeys = left.values(self.joinColumns.toList) + val leftKeys = left.values(leftJoinColumns) rightBroadcastVariable.get(leftKeys) match { case Some(rightRowSet) => for (rightRow <- rightRowSet) yield Row.merge(left.row, new GenericRow(rowWrapper(rightRow).values(rightColsToKeep).toArray)) @@ -90,9 +93,10 @@ class BroadcastJoinRddFunctions(self: RddJoinParam) extends 
From 08e646388675dbf31e51d3cdde38d81cfde73f23 Mon Sep 17 00:00:00 2001
From: Karthik Vadla
Date: Thu, 7 Apr 2016 15:17:28 -0700
Subject: [PATCH 12/15] Review feedback: simplified join condition, fixed left
 join column drop, added join column validation

---
 .../atk/engine/frame/plugins/join/JoinRddFunctions.scala | 7 +++----
 python-client/trustedanalytics/core/frame.py             | 6 +++---
 python-client/trustedanalytics/rest/frame.py             | 8 ++++++++
 3 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala
index 7b76ba1b0d..1eaeabe695 100644
--- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala
+++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala
@@ -108,9 +108,8 @@ object JoinRddFunctions extends Serializable {
     def makeExpression(leftCol: String, rightCol: String): Column = {
       leftFrame(leftCol).equalTo(rightFrame(rightCol))
     }
-    var exps = makeExpression(columnsTuple.head._1, columnsTuple.head._2)
-    columnsTuple.tail.map { case (lc, rc) => exps = exps && makeExpression(lc, rc) }
-    exps
+    val expression = columnsTuple.map { case (lc, rc) => makeExpression(lc, rc) }.reduce(_ && _)
+    expression
   }

   /**
@@ -273,7 +272,7 @@ object JoinRddFunctions extends Serializable {
     val rightSchema = right.frame.frameSchema
     val newSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns))
     val frameRdd = new FrameRdd(newSchema, joinedRdd)
-    val leftColNames = right.joinColumns.map(col => col + "_L")
+    val leftColNames = left.joinColumns.map(col => col + "_L")
     frameRdd.dropColumns(leftColNames.toList)
   }

diff --git a/python-client/trustedanalytics/core/frame.py b/python-client/trustedanalytics/core/frame.py
index d6f2c13683..9dcccb158a 100644
--- a/python-client/trustedanalytics/core/frame.py
+++ b/python-client/trustedanalytics/core/frame.py
@@ -1176,7 +1176,7 @@ def __join(self, right, left_on, right_on=None, how='inner', name=None):

         Join them on the 'numbers' column ('inner' join by default)

-        >>> j = codes.join(colors, ['numbers'])
+        >>> j = codes.join(colors, 'numbers')
         >>> j.inspect()

@@ -1193,7 +1193,7 @@ def __join(self, right, left_on, right_on=None, how='inner', name=None):

         Try a 'left' join, which includes all the rows of the codes frame.

-        >>> j_left = codes.join(colors, ['numbers'], how='left')
+        >>> j_left = codes.join(colors, 'numbers', how='left')
         >>> j_left.inspect()

@@ -1211,7 +1211,7 @@ def __join(self, right, left_on, right_on=None, how='inner', name=None):

         And an outer join:

-        >>> j_outer = codes.join(colors, ['numbers'], how='outer')
+        >>> j_outer = codes.join(colors, 'numbers', how='outer')
         >>> j_outer.inspect()

diff --git a/python-client/trustedanalytics/rest/frame.py b/python-client/trustedanalytics/rest/frame.py
index a9f4d2f84b..19220281c7 100644
--- a/python-client/trustedanalytics/rest/frame.py
+++ b/python-client/trustedanalytics/rest/frame.py
@@ -470,8 +470,16 @@ def inspect(self, frame, n, offset, selected_columns, format_settings):
         return RowsInspection(data, schema, offset=offset, format_settings=format_settings)

     def join(self, left, right, left_on, right_on, how, name=None):
+        if left_on is None:
+            raise ValueError("Please provide column name on which join should be performed")
+        elif isinstance(left_on, basestring):
+            left_on = [left_on]
         if right_on is None:
             right_on = left_on
+        elif isinstance(right_on, basestring):
+            right_on = [right_on]
+        if len(left_on) != len(right_on):
+            raise ValueError("Please provide equal number of join columns")
         arguments = {"name": name,
                      "how": how,
                      "left_frame": {"frame": left.uri, "join_columns": left_on},
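Note: the makeExpression change in patch 12 replaces mutable accumulation (a var plus a side-effecting map over the tail) with a map and a reduce. The same pattern in isolation, assuming Spark SQL DataFrames (parameter names here are illustrative):

    import org.apache.spark.sql.{Column, DataFrame}

    // One equality Column per (leftColumn, rightColumn) pair, folded into a
    // single conjunction. reduce is safe: a join has at least one column pair.
    def joinCondition(leftFrame: DataFrame, rightFrame: DataFrame,
                      columnsTuple: Seq[(String, String)]): Column =
      columnsTuple
        .map { case (lc, rc) => leftFrame(lc).equalTo(rightFrame(rc)) }
        .reduce(_ && _)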
From 29866baf79b24a6c5829637f4acdb554877797f7 Mon Sep 17 00:00:00 2001
From: Karthik Vadla
Date: Wed, 13 Apr 2016 11:49:04 -0700
Subject: [PATCH 13/15] Removed commented-out keyIndices lines

---
 .../trustedanalytics/atk/engine/frame/PythonRddStorage.scala | 2 --
 1 file changed, 2 deletions(-)

diff --git a/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala b/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala
index baeb6117bd..4ca27d03bf 100644
--- a/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala
+++ b/engine/engine-core/src/main/scala/org/trustedanalytics/atk/engine/frame/PythonRddStorage.scala
@@ -91,8 +91,6 @@ object PythonRddStorage {
   def aggregateMapWith(data: FrameRdd, aggregateByColumnKeys: List[String], udf: Udf, udfSchema: Schema, sc: SparkContext): FrameRdd = {
     //Create a new schema which includes keys (KeyedSchema).
     val keyedSchema = udfSchema.copy(columns = data.frameSchema.columns(aggregateByColumnKeys) ++ udfSchema.columns)
-    //track key indices to fetch data during BSON decode.
-    //val keyIndices = for (key <- aggregateByColumnKeys) yield data.frameSchema.columnIndex(key)
     val converter = DataTypes.parseMany(keyedSchema.columns.map(_.dataType).toArray)(_)
     val groupRDD = data.groupByRows(row => row.values(aggregateByColumnKeys))
     val pyRdd = aggregateRddToPyRdd(udf, groupRDD, sc)

From 749f5fad1285163f382ffeefc1a41ff859562bb7 Mon Sep 17 00:00:00 2001
From: Karthik Vadla
Date: Mon, 18 Apr 2016 17:30:55 -0700
Subject: [PATCH 14/15] Removed the _L/_R suffix assumption; retrieve dropped
 join column names by index

---
 .../atk/engine/frame/plugins/join/JoinRddFunctions.scala | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala
index 1eaeabe695..83a4df759f 100644
--- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala
+++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala
@@ -272,7 +272,9 @@ object JoinRddFunctions extends Serializable {
     val rightSchema = right.frame.frameSchema
     val newSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns))
     val frameRdd = new FrameRdd(newSchema, joinedRdd)
-    val leftColNames = left.joinColumns.map(col => col + "_L")
+    val leftColIndices = leftSchema.columnIndices(left.joinColumns)
+    val leftColNames = leftColIndices.map(colindex => newSchema.column(colindex).name)
+    //val leftColNames = left.joinColumns.map(col => col + "_L")
     frameRdd.dropColumns(leftColNames.toList)
   }

@@ -303,7 +305,9 @@ object JoinRddFunctions extends Serializable {
     else {
       val newSchema = FrameSchema(Schema.join(leftSchema.columns, rightSchema.columns))
       val frameRdd = new FrameRdd(newSchema, joinedRdd)
-      val rightColNames = right.joinColumns.map(col => col + "_R")
+      val rightColIndices = rightSchema.columnIndices(right.joinColumns).map(rightindex => leftSchema.columns.size + rightindex)
+      val rightColNames = rightColIndices.map(colindex => newSchema.column(colindex).name)
+      //val rightColNames = right.joinColumns.map(col => col + "_R")
       frameRdd.dropColumns(rightColNames.toList)
     }
   }
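Note: patch 14's index arithmetic works because Schema.join appends the right schema's columns after the left schema's, so a left join column keeps its left-schema index and a right join column lands at leftSchema.columns.size plus its right-schema index; the possibly renamed (suffixed) names can then be read back positionally instead of hard-coding a suffix convention. A minimal sketch with plain collections (Col is a stand-in for the ATK Column type):

    case class Col(name: String)

    // joined = left columns followed by right columns, renamed on clashes.
    // Recover the post-join names of the join keys by position.
    def namesToDrop(leftCols: Seq[Col], rightCols: Seq[Col], joined: Seq[Col],
                    joinNames: Seq[String], fromLeft: Boolean): Seq[String] = {
      val base = if (fromLeft) leftCols else rightCols
      val offset = if (fromLeft) 0 else leftCols.size
      joinNames
        .map(n => base.indexWhere(_.name == n))
        .map(i => joined(i + offset).name)
    }

For example, if both frames have a "numbers" column that the joined schema renames to "numbers_L" and "numbers_R", then namesToDrop(..., joinNames = Seq("numbers"), fromLeft = true) yields Seq("numbers_L") regardless of what suffix the schema actually chose.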
From 0ffbb4591a308079780cdce6eb8ff24cbc999a91 Mon Sep 17 00:00:00 2001
From: Karthik Vadla
Date: Tue, 19 Apr 2016 09:18:45 -0700
Subject: [PATCH 15/15] Removed leftover commented-out suffix code

---
 .../atk/engine/frame/plugins/join/JoinRddFunctions.scala | 2 --
 1 file changed, 2 deletions(-)

diff --git a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala
index 83a4df759f..6ef485bd2d 100644
--- a/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala
+++ b/engine-plugins/frame-plugins/src/main/scala/org/trustedanalytics/atk/engine/frame/plugins/join/JoinRddFunctions.scala
@@ -274,7 +274,6 @@ object JoinRddFunctions extends Serializable {
     val frameRdd = new FrameRdd(newSchema, joinedRdd)
     val leftColIndices = leftSchema.columnIndices(left.joinColumns)
     val leftColNames = leftColIndices.map(colindex => newSchema.column(colindex).name)
-    //val leftColNames = left.joinColumns.map(col => col + "_L")
     frameRdd.dropColumns(leftColNames.toList)
   }

@@ -307,7 +306,6 @@ object JoinRddFunctions extends Serializable {
     val frameRdd = new FrameRdd(newSchema, joinedRdd)
     val rightColIndices = rightSchema.columnIndices(right.joinColumns).map(rightindex => leftSchema.columns.size + rightindex)
     val rightColNames = rightColIndices.map(colindex => newSchema.column(colindex).name)
-    //val rightColNames = right.joinColumns.map(col => col + "_R")
     frameRdd.dropColumns(rightColNames.toList)
   }
 }