BAEL-3862: Spark differences DS, DF, RDD (#9976)

Aaron Juarez 2020-10-02 12:46:48 -04:00 committed by GitHub
parent e795b87607
commit 67981e7cba
6 changed files with 2609 additions and 0 deletions

File diff suppressed because it is too large

TouristData.java
@@ -0,0 +1,75 @@
package com.baeldung.differences.dataframe.dataset.rdd;
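
// Plain Java bean mapping one row of data/Tourist.csv; DatasetUnitTest builds a
// typed Dataset<TouristData> from it via Encoders.bean, which relies on these getters/setters.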
public class TouristData {

    private String region;
    private String country;
    private String year;
    private String series;
    private Double value;
    private String footnotes;
    private String source;

    public String getRegion() {
        return region;
    }

    public void setRegion(String region) {
        this.region = region;
    }

    public String getCountry() {
        return country;
    }

    public void setCountry(String country) {
        this.country = country;
    }

    public String getYear() {
        return year;
    }

    public void setYear(String year) {
        this.year = year;
    }

    public String getSeries() {
        return series;
    }

    public void setSeries(String series) {
        this.series = series;
    }

    public Double getValue() {
        return value;
    }

    public void setValue(Double value) {
        this.value = value;
    }

    public String getFootnotes() {
        return footnotes;
    }

    public void setFootnotes(String footnotes) {
        this.footnotes = footnotes;
    }

    public String getSource() {
        return source;
    }

    public void setSource(String source) {
        this.source = source;
    }

    @Override
    public String toString() {
        return "TouristData [region=" + region + ", country=" + country + ", year=" + year + ", series=" + series
            + ", value=" + value + ", footnotes=" + footnotes + ", source=" + source + "]";
    }
}

ActionsUnitTest.java
@@ -0,0 +1,67 @@
package com.baeldung.differences.rdd;
import static org.junit.Assert.assertEquals;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import scala.Tuple2;
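
// Exercises RDD actions (count, collect) and the reduceByKey transformation
// against the raw lines of data/Tourist.csv.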
public class ActionsUnitTest {

    private static JavaRDD<String> tourists;
    private static JavaSparkContext sc;
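    // Matches commas outside double quotes, so quoted fields containing commas are not split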
    public static final String COMMA_DELIMITER = ",(?=([^\"]*\"[^\"]*\")*[^\"]*$)";

    @BeforeClass
    public static void init() {
        SparkConf conf = new SparkConf().setAppName("reduce")
            .setMaster("local[*]");
        sc = new JavaSparkContext(conf);
        tourists = sc.textFile("data/Tourist.csv").filter(line -> !line.startsWith("Region"));
    }

    @AfterClass
    public static void cleanup() {
        sc.close();
    }
    @Test
    public void whenDistinctCount_thenReturnDistinctNumRecords() {
        JavaRDD<String> countries = tourists.map(line -> {
            String[] columns = line.split(COMMA_DELIMITER);
            return columns[1];
        })
            .distinct();
        Long numberOfCountries = countries.count();
        System.out.println("Count: " + numberOfCountries);
        assertEquals(Long.valueOf(220), numberOfCountries);
    }
    @Test
    public void whenReduceByKeySum_thenTotalValuePerKey() {
        JavaRDD<String> touristsExpenditure = tourists.filter(line -> line.split(COMMA_DELIMITER)[3].contains("expenditure"));
        JavaPairRDD<String, Double> expenditurePairRdd = touristsExpenditure.mapToPair(line -> {
            String[] columns = line.split(COMMA_DELIMITER);
            return new Tuple2<>(columns[1], Double.valueOf(columns[6]));
        });
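        // sum the expenditure values per country key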
        List<Tuple2<String, Double>> totalByCountry = expenditurePairRdd.reduceByKey((x, y) -> x + y)
            .collect();
        System.out.println("Total per Country: " + totalByCountry);
        for (Tuple2<String, Double> tuple : totalByCountry) {
            if (tuple._1.equals("Mexico")) {
                assertEquals(Double.valueOf(99164), tuple._2);
            }
        }
    }
}

DataFrameUnitTest.java
@@ -0,0 +1,74 @@
package com.baeldung.differences.rdd;
import static org.apache.spark.sql.functions.col;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.sql.DataFrameReader;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
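
// Works with the untyped DataFrame API (Dataset<Row>): columns are addressed
// by name and values are retrieved as generic objects.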
public class DataFrameUnitTest {

    private static SparkSession session;
    private static Dataset<Row> data;
    @BeforeClass
    public static void init() {
        session = SparkSession.builder()
            .appName("TouristDataFrameExample")
            .master("local[*]")
            .getOrCreate();
        DataFrameReader dataFrameReader = session.read();
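        // treat the first CSV line as the header, so columns get their proper names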
        data = dataFrameReader.option("header", "true")
            .csv("data/Tourist.csv");
    }

    @AfterClass
    public static void cleanup() {
        session.stop();
    }
    @Test
    public void whenSelectSpecificColumns_thenColumnsFiltered() {
        Dataset<Row> selectedData = data.select(col("country"), col("year"), col("value"));
        selectedData.show();
        List<String> resultList = Arrays.asList(selectedData.columns());
        assertTrue(resultList.contains("country"));
        assertTrue(resultList.contains("year"));
        assertTrue(resultList.contains("value"));
        assertFalse(resultList.contains("Series"));
    }

    @Test
    public void whenFilteringByCountry_thenCountryRecordsSelected() {
        Dataset<Row> filteredData = data.filter(col("country").equalTo("Mexico"));
        filteredData.show();
        filteredData.foreach(record -> {
            assertEquals("Mexico", record.get(1)); // country is the second column
        });
    }
    @Test
    public void whenGroupCountByCountry_thenCountryTotalRecords() {
        Dataset<Row> recordsPerCountry = data.groupBy(col("country"))
            .count();
        recordsPerCountry.show();
        Dataset<Row> filteredData = recordsPerCountry.filter(col("country").equalTo("Sweden"));
        assertEquals(Long.valueOf(12), filteredData.first()
            .get(1));
    }
}

DatasetUnitTest.java
@@ -0,0 +1,83 @@
package com.baeldung.differences.rdd;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.sum;
import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.sql.DataFrameReader;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import com.baeldung.differences.dataframe.dataset.rdd.TouristData;
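
// Works with the strongly typed Dataset API: rows are mapped to TouristData
// beans, so fields are accessed through typed getters instead of by index.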
public class DatasetUnitTest {

    private static SparkSession session;
    private static Dataset<TouristData> typedDataset;
    @BeforeClass
    public static void init() {
        session = SparkSession.builder()
            .appName("TouristDatasetExample")
            .master("local[*]")
            .getOrCreate();
        DataFrameReader dataFrameReader = session.read();
        Dataset<Row> data = dataFrameReader.option("header", "true")
            .csv("data/Tourist.csv");
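        // select the bean's properties, casting value to double so it maps to the Double field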
        Dataset<Row> responseWithSelectedColumns = data.select(col("region"),
            col("country"), col("year"), col("series"), col("value").cast("double"),
            col("footnotes"), col("source"));
        typedDataset = responseWithSelectedColumns.as(Encoders.bean(TouristData.class));
    }

    @AfterClass
    public static void cleanup() {
        session.stop();
    }
    @Test
    public void whenFilteringByCountry_thenCountryRecordsSelected() {
        Dataset<TouristData> selectedData = typedDataset
            .filter((FilterFunction<TouristData>) record -> record.getCountry()
                .equals("Norway"));
        selectedData.show();
        selectedData.foreach(record -> {
            assertEquals("Norway", record.getCountry());
        });
    }
    @Test
    public void whenGroupCountByCountry_thenCountryTotalRecords() {
        Dataset<Row> countriesCount = typedDataset.groupBy(typedDataset.col("country"))
            .count();
        countriesCount.show();
        assertEquals(Long.valueOf(220), Long.valueOf(countriesCount.count()));
    }
    @Test
    public void whenFilteredByPropertyRange_thenRetrieveValidRecords() {
        // Filter records with existing data for years strictly between 2010 and 2017
        typedDataset.filter((FilterFunction<TouristData>) record -> record.getYear() != null
            && (Long.valueOf(record.getYear()) > 2010 && Long.valueOf(record.getYear()) < 2017))
            .show();
    }
    @Test
    public void whenSumValue_thenRetrieveTotalValue() {
        // Total tourist expenditure by country
        typedDataset.filter((FilterFunction<TouristData>) record -> record.getValue() != null
            && record.getSeries()
                .contains("expenditure"))
            .groupBy("country")
            .agg(sum("value"))
            .show();
    }
}

TransformationsUnitTest.java
@@ -0,0 +1,63 @@
package com.baeldung.differences.rdd;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
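
// Exercises RDD transformations (map, filter, distinct) and verifies the
// results written out by saveAsTextFile.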
public class TransformationsUnitTest {

    public static final String COMMA_DELIMITER = ",(?=([^\"]*\"[^\"]*\")*[^\"]*$)";
    private static JavaSparkContext sc;
    private static JavaRDD<String> tourists;

    @BeforeClass
    public static void init() {
        SparkConf conf = new SparkConf().setAppName("uppercaseCountries")
            .setMaster("local[*]");
        sc = new JavaSparkContext(conf);
        tourists = sc.textFile("data/Tourist.csv")
            .filter(line -> !line.startsWith("Region")); // filter header row
    }

    @AfterClass
    public static void cleanup() {
        sc.close();
    }
    @Test
    public void whenMapUpperCase_thenCountryNameUppercased() {
        JavaRDD<String> upperCaseCountries = tourists.map(line -> {
            String[] columns = line.split(COMMA_DELIMITER);
            return columns[1].toUpperCase();
        })
            .distinct();
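        // saveAsTextFile writes a directory of part files, one per RDD partition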
        upperCaseCountries.saveAsTextFile("data/output/uppercase.txt");
        upperCaseCountries.foreach(country -> {
            // strip non-alphabetic characters before checking the case
            country = country.replaceAll("[^a-zA-Z]", "");
            assertTrue(StringUtils.isAllUpperCase(country));
        });
    }
    @Test
    public void whenFilterByCountry_thenShowRequestedCountryRecords() {
        JavaRDD<String> touristsInMexico = tourists.filter(line -> line.split(COMMA_DELIMITER)[1].equals("Mexico"));
        touristsInMexico.saveAsTextFile("data/output/touristInMexico.txt");
        touristsInMexico.foreach(record -> {
            assertEquals("Mexico", record.split(COMMA_DELIMITER)[1]);
        });
    }
}