Apache Spark Tutorial
This tutorial covers Apache Spark through its Python API, PySpark, for distributed data processing and machine learning. Always verify that the version you install is compatible with your Python and Java environment.
1. What is Apache Spark?
Apache Spark is a unified analytics engine for large-scale data processing. It provides high-level APIs in Python, Java, Scala, and R, along with an optimized engine that supports general execution graphs for data analysis.
Core capabilities:
| Feature | Description |
|---|---|
| Speed | In-memory computation that can be dramatically faster than Hadoop MapReduce, especially for iterative workloads |
| Ease of Use | High-level APIs and 80+ operators for transformations and actions |
| Unified Platform | SQL queries, streaming data, machine learning, and graph processing |
| Multiple Languages | Python (PySpark), Scala, Java, R, and SQL |
| Rich Ecosystem | Spark SQL, Spark Streaming, MLlib (ML), GraphX (graphs) |
| Cluster Computing | Run on Hadoop, Kubernetes, Mesos, standalone, or cloud |
2. Architecture Overview
┌────────────────────────────────────────────────────────┐
│                   Spark Architecture                   │
├────────────────────────────────────────────────────────┤
│                                                        │
│              Driver Program (SparkContext)             │
│                           ↓                            │
│                    Cluster Manager                     │
│              (Standalone/YARN/Mesos/K8s)               │
│                           ↓                            │
│               ┌───────────┴───────────┐                │
│               ↓                       ↓                │
│          Worker Node             Worker Node           │
│          ├─ Executor             ├─ Executor           │
│          │   ├─ Cache            │   ├─ Cache          │
│          │   └─ Tasks            │   └─ Tasks          │
│          └─ Executor             └─ Executor           │
│              ├─ Cache                ├─ Cache          │
│              └─ Tasks                └─ Tasks          │
│                                                        │
└────────────────────────────────────────────────────────┘
Key components:
- Driver Program: Your main program that creates SparkContext and coordinates execution
- Cluster Manager: Allocates resources across applications (YARN, Mesos, Kubernetes, Standalone)
- Worker Nodes: Machines that run Spark executors
- Executors: Processes that run computations and store data for your application
- Tasks: Units of work sent to executors
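Most of these pieces are visible from a PySpark program itself. A minimal sketch (the app name and master setting are illustrative):
from pyspark.sql import SparkSession

# The driver program starts here: building a SparkSession creates the
# SparkContext that registers the application with the cluster manager
spark = SparkSession.builder \
    .appName("architecture-demo") \
    .master("local[2]") \
    .getOrCreate()

sc = spark.sparkContext
print(sc.master)              # which master / cluster manager is in use
print(sc.defaultParallelism)  # default number of task slots available
print(sc.applicationId)       # ID assigned to this application

spark.stop()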
3. Installation & Setup
Install PySpark
Using pip:
# Create virtual environment
python -m venv spark_env
source spark_env/bin/activate # On Windows: spark_env\Scripts\activate
# Install PySpark
pip install pyspark
# Optional: Install with specific version
pip install pyspark==3.5.0
# With additional dependencies for pandas integration
pip install pyspark pandas pyarrow
Using conda:
# Create conda environment
conda create -n spark_env python=3.11
conda activate spark_env
# Install PySpark
conda install -c conda-forge pyspark
# With additional dependencies
conda install -c conda-forge pyspark pandas pyarrow
Verify Installation
import pyspark
print(f"PySpark version: {pyspark.__version__}")
Set Up Java (Required)
Spark runs on the JVM and requires a compatible Java installation (for Spark 3.x this means Java 8, 11, or 17; check the requirements for your Spark release):
# Check Java version
java -version
# If Java not installed:
# Ubuntu/Debian
sudo apt-get update
sudo apt-get install openjdk-11-jdk
# macOS (using Homebrew)
brew install openjdk@11
# Set JAVA_HOME
export JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64 # Adjust path as needed
4. Getting Started with PySpark
Create SparkSession
The entry point to Spark functionality is SparkSession:
from pyspark.sql import SparkSession
# Create SparkSession
spark = SparkSession.builder \
.appName("MySparkApp") \
.master("local[*]") \
.config("spark.driver.memory", "4g") \
.getOrCreate()
# Get SparkContext from SparkSession
sc = spark.sparkContext
print(f"Spark Version: {spark.version}")
print(f"Spark Master: {spark.sparkContext.master}")
Configuration options:
- local[*]: Run locally with as many worker threads as logical cores
- local[4]: Run locally with 4 worker threads
- spark://HOST:PORT: Connect to a Spark standalone cluster
- yarn: Connect to a YARN cluster
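A common pattern is to leave the master out of the code entirely and supply it at launch time, so the same script runs locally or on a cluster unchanged (a small sketch; the app name is illustrative):
from pyspark.sql import SparkSession

# Omitting .master() lets the master be chosen at launch time, e.g.
#   spark-submit --master yarn my_app.py
# (launch with spark-submit so a master URL is provided)
spark = SparkSession.builder.appName("portable-app").getOrCreate()
print(spark.sparkContext.master)
spark.stop()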
Stop SparkSession
# Always stop SparkSession when done
spark.stop()
5. RDD (Resilient Distributed Dataset)
RDDs are the fundamental data structure of Spark - immutable distributed collections of objects.
Creating RDDs
# From Python collection
data = [1, 2, 3, 4, 5]
rdd = sc.parallelize(data)
# From text file
text_rdd = sc.textFile("path/to/file.txt")
# From multiple files
multi_rdd = sc.textFile("path/to/*.txt")
RDD Transformations
Transformations are lazy operations that return a new RDD:
# Map: Apply function to each element
rdd = sc.parallelize([1, 2, 3, 4, 5])
squared_rdd = rdd.map(lambda x: x * x)
# Result: [1, 4, 9, 16, 25]
# Filter: Keep elements that satisfy condition
even_rdd = rdd.filter(lambda x: x % 2 == 0)
# Result: [2, 4]
# FlatMap: Map then flatten
text = sc.parallelize(["hello world", "apache spark"])
words_rdd = text.flatMap(lambda line: line.split(" "))
# Result: ["hello", "world", "apache", "spark"]
# Distinct: Remove duplicates
duplicates = sc.parallelize([1, 2, 2, 3, 3, 3])
unique_rdd = duplicates.distinct()
# Result: [1, 2, 3]
# Union: Combine two RDDs
rdd1 = sc.parallelize([1, 2, 3])
rdd2 = sc.parallelize([3, 4, 5])
union_rdd = rdd1.union(rdd2)
# Result: [1, 2, 3, 3, 4, 5]
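Because transformations are lazy, the lines above only build a plan; no data is processed until an action runs. A quick way to see this:
import time

rdd = sc.parallelize(range(1_000_000))

start = time.time()
mapped = rdd.map(lambda x: x * 2).filter(lambda x: x % 3 == 0)
print(f"Defining transformations took {time.time() - start:.4f}s")  # near-instant

start = time.time()
print(mapped.count())  # the action triggers the actual computation
print(f"Running the action took {time.time() - start:.4f}s")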
RDD Actions
Actions trigger computation and return results:
rdd = sc.parallelize([1, 2, 3, 4, 5])
# Collect: Return all elements as array
print(rdd.collect()) # [1, 2, 3, 4, 5]
# Count: Number of elements
print(rdd.count()) # 5
# First: Get first element
print(rdd.first()) # 1
# Take: Get first n elements
print(rdd.take(3)) # [1, 2, 3]
# Reduce: Aggregate using function
sum_result = rdd.reduce(lambda a, b: a + b)
print(sum_result) # 15
# foreach: Apply function to each element (returns nothing; print output appears in executor logs, not necessarily the driver console)
rdd.foreach(lambda x: print(x))
Key-Value RDD Operations
# Create key-value RDD
pairs = sc.parallelize([("a", 1), ("b", 2), ("a", 3), ("b", 4)])
# reduceByKey: Combine values for each key
sum_by_key = pairs.reduceByKey(lambda a, b: a + b)
print(sum_by_key.collect()) # [('a', 4), ('b', 6)]
# groupByKey: Group values for each key
grouped = pairs.groupByKey()
print([(k, list(v)) for k, v in grouped.collect()])
# [('a', [1, 3]), ('b', [2, 4])]
# sortByKey: Sort by keys
sorted_pairs = pairs.sortByKey()
# mapValues: Apply function to values only
incremented = pairs.mapValues(lambda x: x + 10)
print(incremented.collect()) # [('a', 11), ('b', 12), ('a', 13), ('b', 14)]
# Join operations
rdd1 = sc.parallelize([("a", 1), ("b", 2)])
rdd2 = sc.parallelize([("a", "x"), ("b", "y")])
joined = rdd1.join(rdd2)
print(joined.collect()) # [('a', (1, 'x')), ('b', (2, 'y'))]
6. DataFrames and Spark SQL
DataFrames are distributed collections with named columns (like SQL tables or pandas DataFrames).
Creating DataFrames
from pyspark.sql import Row
# From Python list of tuples
data = [("Alice", 25), ("Bob", 30), ("Charlie", 35)]
df = spark.createDataFrame(data, ["name", "age"])
# From list of Row objects
data = [
Row(name="Alice", age=25, city="NYC"),
Row(name="Bob", age=30, city="LA"),
Row(name="Charlie", age=35, city="Chicago")
]
df = spark.createDataFrame(data)
# From pandas DataFrame
import pandas as pd
pandas_df = pd.DataFrame({
"name": ["Alice", "Bob", "Charlie"],
"age": [25, 30, 35]
})
df = spark.createDataFrame(pandas_df)
# From CSV file
df = spark.read.csv("path/to/file.csv", header=True, inferSchema=True)
# From JSON file
df = spark.read.json("path/to/file.json")
# From Parquet file
df = spark.read.parquet("path/to/file.parquet")
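With inferSchema, Spark scans the data to guess column types; for repeatable jobs it is usually safer to declare the schema explicitly. A sketch (the column names are just an example):
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

schema = StructType([
    StructField("name", StringType(), True),   # True = nullable
    StructField("age", IntegerType(), True),
    StructField("city", StringType(), True),
])

# No inference pass over the data; types are exactly what you declared
df = spark.read.csv("path/to/file.csv", header=True, schema=schema)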
DataFrame Operations
# Show first rows
df.show()
df.show(5) # Show first 5 rows
df.show(5, truncate=False) # Don't truncate long strings
# Schema and structure
df.printSchema()
df.columns # List column names
df.dtypes # List column names and types
# Basic statistics
df.describe().show()
df.count()
# Select columns
df.select("name").show()
df.select("name", "age").show()
df.select(df["name"], df["age"] + 1).show()
# Filter rows
df.filter(df["age"] > 25).show()
df.filter((df["age"] > 25) & (df["name"] == "Bob")).show()
df.where(df["age"] > 25).show() # Alias for filter
# Add new column
from pyspark.sql.functions import col
df = df.withColumn("age_plus_10", col("age") + 10)
# Rename column
df = df.withColumnRenamed("age", "person_age")
# Drop column
df = df.drop("age_plus_10")
# Sort
df.orderBy("age").show()
df.orderBy(df["age"].desc()).show()
df.sort(df["age"].asc()).show()
# Group by and aggregate
df.groupBy("city").count().show()
df.groupBy("city").avg("age").show()
df.groupBy("city").agg({"age": "max", "name": "count"}).show()
Spark SQL
# Register DataFrame as temporary view
df.createOrReplaceTempView("people")
# Run SQL queries
result = spark.sql("""
SELECT name, age
FROM people
WHERE age > 25
ORDER BY age DESC
""")
result.show()
# Complex SQL example
result = spark.sql("""
SELECT
city,
COUNT(*) as count,
AVG(age) as avg_age,
MAX(age) as max_age
FROM people
GROUP BY city
HAVING COUNT(*) > 1
""")
result.show()
Built-in Functions
# Wildcard import is handy for short examples, but it shadows Python built-ins
# such as sum, max, min, abs, and round; in application code, prefer:
#   from pyspark.sql import functions as F
from pyspark.sql.functions import *
# String functions
df = df.withColumn("name_upper", upper(col("name")))
df = df.withColumn("name_lower", lower(col("name")))
df = df.withColumn("name_length", length(col("name")))
# Numeric functions
df = df.withColumn("age_abs", abs(col("age")))
df = df.withColumn("age_rounded", round(col("age"), 0))
# Date functions
from pyspark.sql.functions import current_date, current_timestamp, date_add
df = df.withColumn("current_date", current_date())
df = df.withColumn("next_week", date_add(current_date(), 7))
# Conditional functions
df = df.withColumn(
"age_category",
when(col("age") < 30, "Young")
.when(col("age") < 50, "Middle")
.otherwise("Senior")
)
# Aggregation functions
df.select(
avg("age"),
max("age"),
min("age"),
sum("age"),
count("*")
).show()
7. Working with Different Data Formats
CSV
# Read CSV
df = spark.read \
.option("header", "true") \
.option("inferSchema", "true") \
.option("sep", ",") \
.csv("path/to/file.csv")
# Write CSV
df.write \
.option("header", "true") \
.mode("overwrite") \
.csv("path/to/output")
JSON
# Read JSON
df = spark.read.json("path/to/file.json")
# Read multi-line JSON
df = spark.read.option("multiLine", "true").json("path/to/file.json")
# Write JSON
df.write.mode("overwrite").json("path/to/output")
Parquet (Recommended for Big Data)
# Read Parquet
df = spark.read.parquet("path/to/file.parquet")
# Write Parquet with compression
df.write \
.mode("overwrite") \
.option("compression", "snappy") \
.parquet("path/to/output")
Database (JDBC)
# Read from database
df = spark.read \
.format("jdbc") \
.option("url", "jdbc:postgresql://localhost:5432/mydb") \
.option("dbtable", "mytable") \
.option("user", "username") \
.option("password", "password") \
.option("driver", "org.postgresql.Driver") \
.load()
# Write to database
df.write \
.format("jdbc") \
.option("url", "jdbc:postgresql://localhost:5432/mydb") \
.option("dbtable", "mytable") \
.option("user", "username") \
.option("password", "password") \
.option("driver", "org.postgresql.Driver") \
.mode("overwrite") \
.save()
8. Performance Optimization
Caching and Persistence
# Cache DataFrame in memory
df.cache()
df.persist() # Same as cache()
# Custom persistence levels
from pyspark import StorageLevel
df.persist(StorageLevel.MEMORY_AND_DISK)
df.persist(StorageLevel.DISK_ONLY)
# Unpersist to free memory
df.unpersist()
Partitioning
# Check current partitions
print(df.rdd.getNumPartitions())
# Repartition (increases or decreases partitions)
df_repartitioned = df.repartition(10)
df_repartitioned = df.repartition("city") # Partition by column
# Coalesce (only decreases partitions, more efficient)
df_coalesced = df.coalesce(2)
Broadcast Variables
For small datasets that need to be shared across all nodes:
# Broadcast small lookup table
lookup_dict = {"A": 1, "B": 2, "C": 3}
broadcast_var = sc.broadcast(lookup_dict)
# Use in transformations
def map_function(value):
    return broadcast_var.value.get(value, 0)
rdd = sc.parallelize(["A", "B", "D"])
result = rdd.map(map_function)
print(result.collect()) # [1, 2, 0]
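For DataFrames, the same idea appears as a broadcast join hint: marking the small table with broadcast() asks Spark to ship it to every executor instead of shuffling the large side (a sketch with illustrative tables):
from pyspark.sql.functions import broadcast

large_df = spark.createDataFrame(
    [(1, "A"), (2, "B"), (3, "C"), (4, "A")], ["id", "code"])
small_df = spark.createDataFrame(
    [("A", "alpha"), ("B", "beta"), ("C", "gamma")], ["code", "label"])

# The broadcast hint avoids shuffling large_df across the cluster
joined = large_df.join(broadcast(small_df), on="code", how="left")
joined.show()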
9. Machine Learning with MLlib
Basic ML Pipeline Example
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator
# Prepare data
data = spark.createDataFrame([
(1.0, 2.0, 3.0),
(2.0, 3.0, 5.0),
(3.0, 4.0, 7.0),
(4.0, 5.0, 9.0),
], ["feature1", "feature2", "label"])
# Assemble features
assembler = VectorAssembler(
inputCols=["feature1", "feature2"],
outputCol="features"
)
data = assembler.transform(data)
# Split data
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)
# Create and train model
lr = LinearRegression(featuresCol="features", labelCol="label")
model = lr.fit(train_data)
# Make predictions
predictions = model.transform(test_data)
predictions.select("features", "label", "prediction").show()
# Evaluate model
evaluator = RegressionEvaluator(labelCol="label", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print(f"Root Mean Squared Error: {rmse}")
# Model summary
print(f"Coefficients: {model.coefficients}")
print(f"Intercept: {model.intercept}")
Classification Example
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
# Binary classification expects labels of 0.0 or 1.0, so build a small
# labeled dataset (the regression data above has non-binary labels)
clf_data = spark.createDataFrame([
    (1.0, 2.0, 0.0),
    (2.0, 3.0, 0.0),
    (3.0, 4.0, 1.0),
    (4.0, 5.0, 1.0),
], ["feature1", "feature2", "label"])
clf_data = VectorAssembler(
    inputCols=["feature1", "feature2"],
    outputCol="features"
).transform(clf_data)
# Create classifier
lr_classifier = LogisticRegression(
    featuresCol="features",
    labelCol="label",
    maxIter=10
)
# Train model (for brevity, train and evaluate on the same toy data)
model = lr_classifier.fit(clf_data)
# Predictions
predictions = model.transform(clf_data)
# Evaluate
evaluator = BinaryClassificationEvaluator(labelCol="label")
auc = evaluator.evaluate(predictions)
print(f"AUC: {auc}")
10. Common Use Cases
Word Count Example
# Classic word count
text_file = sc.textFile("path/to/text.txt")
counts = text_file \
.flatMap(lambda line: line.split(" ")) \
.map(lambda word: (word, 1)) \
.reduceByKey(lambda a, b: a + b)
# Get top 10 words
top_words = counts.takeOrdered(10, key=lambda x: -x[1])
print(top_words)
Data Processing Pipeline
# Read data
df = spark.read.csv("input.csv", header=True, inferSchema=True)
# Clean data
df = df.dropna() # Remove null values
df = df.dropDuplicates() # Remove duplicates
# Transform data
from pyspark.sql.functions import col, when, count, avg
df = df.withColumn(
"age_group",
when(col("age") < 18, "Minor")
.when(col("age") < 65, "Adult")
.otherwise("Senior")
)
# Aggregate
result = df.groupBy("age_group") \
.agg(
count("*").alias("count"),
avg("salary").alias("avg_salary")
) \
.orderBy("age_group")
# Save results
result.write.parquet("output.parquet", mode="overwrite")
Join Operations
# Create sample DataFrames
customers = spark.createDataFrame([
(1, "Alice", "NYC"),
(2, "Bob", "LA"),
(3, "Charlie", "Chicago")
], ["id", "name", "city"])
orders = spark.createDataFrame([
(1, 1, 100),
(2, 1, 200),
(3, 2, 150),
(4, 4, 300) # Customer 4 doesn't exist
], ["order_id", "customer_id", "amount"])
# Inner join
inner_joined = customers.join(
orders,
customers.id == orders.customer_id,
"inner"
)
# Left join
left_joined = customers.join(
orders,
customers.id == orders.customer_id,
"left"
)
from pyspark.sql.functions import count, sum
# Aggregate after join
result = inner_joined.groupBy("name") \
.agg(
count("order_id").alias("order_count"),
sum("amount").alias("total_amount")
)
result.show()
11. Best Practices
Performance Tips
- Use DataFrames over RDDs: DataFrames benefit from the Catalyst query optimizer and Tungsten execution engine
- Cache wisely: Cache data that's used multiple times
- Partition appropriately: More partitions for large data, fewer for small data
- Use Parquet format: Better compression and columnar storage
- Filter early: Apply filters before joins and aggregations
- Avoid UDFs when possible: Built-in functions run inside Spark's engine and are usually much faster (see the sketch after this list)
- Broadcast small datasets: For join optimization
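To make the UDF point concrete, here is the same transformation written both ways; the built-in column expression stays inside the JVM, while the Python UDF pays per-row serialization overhead (a small sketch):
from pyspark.sql.functions import udf, upper, col
from pyspark.sql.types import StringType

df = spark.createDataFrame([("alice",), ("bob",)], ["name"])

# Python UDF: rows are shipped to a Python worker process and back
to_upper_udf = udf(lambda s: s.upper() if s is not None else None, StringType())
df.withColumn("name_upper", to_upper_udf(col("name"))).show()

# Built-in function: runs inside Spark's engine, usually much faster
df.withColumn("name_upper", upper(col("name"))).show()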
Code Organization
# Good practice: Reusable function
def read_and_clean_data(file_path):
    df = spark.read.csv(file_path, header=True, inferSchema=True)
    df = df.dropna()
    df = df.dropDuplicates()
    return df
# Good practice: Configuration
config = {
"input_path": "input.csv",
"output_path": "output.parquet",
"partition_count": 10
}
# Main processing
df = read_and_clean_data(config["input_path"])
df = df.repartition(config["partition_count"])
df.write.parquet(config["output_path"], mode="overwrite")
Error Handling
from pyspark.sql.functions import col

try:
    df = spark.read.csv("input.csv", header=True, inferSchema=True)
    result = df.filter(col("age") > 18).count()
    print(f"Count: {result}")
except Exception as e:
    print(f"Error: {str(e)}")
finally:
    spark.stop()
12. Common Pitfalls to Avoid
- Not stopping SparkSession: Always call spark.stop() when done
- Over-partitioning small data: Creates unnecessary scheduling overhead
- Collecting large datasets: collect() brings all data to the driver; use take() instead (see the sketch after this list)
- Unnecessary shuffles: Minimize operations like groupByKey(); prefer reduceByKey()
- Not caching iterative algorithms: Cache intermediate results in ML pipelines
- Using the wrong file format: Use Parquet for large datasets, not CSV
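For the collect() pitfall, a small sketch of driver-friendly alternatives (the output path is illustrative):
# Risky on big data: pulls every row into driver memory
# all_rows = df.collect()

# Safer ways to inspect or keep results
print(df.take(5))       # first 5 rows as a list of Row objects
df.limit(5).show()      # bounded preview
df.write.parquet("output/full_result.parquet", mode="overwrite")  # keep full data distributed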
13. Monitoring and Debugging
Spark UI
Access Spark UI at http://localhost:4040 when running locally:
- Jobs: View job status and stages
- Stages: Detailed task information
- Storage: Cached RDDs/DataFrames
- Environment: Configuration settings
- Executors: Resource usage
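When a session is active, the driver also reports where its UI is actually listening, which helps when port 4040 is already taken and Spark falls back to the next free port:
# The active SparkContext exposes the URL of its web UI
print(spark.sparkContext.uiWebUrl)  # e.g. http://<driver-host>:4040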
Logging
# Set log level
spark.sparkContext.setLogLevel("WARN") # Options: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN
# Custom logging
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.info("Processing started")
14. Complete Example: ETL Pipeline
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, avg, count, sum, when
# Initialize Spark
spark = SparkSession.builder \
.appName("ETL Pipeline") \
.master("local[*]") \
.config("spark.driver.memory", "4g") \
.getOrCreate()
try:
    # Extract
    sales_df = spark.read.csv("sales.csv", header=True, inferSchema=True)
    products_df = spark.read.csv("products.csv", header=True, inferSchema=True)

    # Transform
    # Clean data
    sales_df = sales_df.dropna()
    products_df = products_df.dropna()

    # Join datasets
    joined_df = sales_df.join(
        products_df,
        sales_df.product_id == products_df.id,
        "inner"
    )

    # Add calculated columns
    joined_df = joined_df.withColumn(
        "revenue",
        col("quantity") * col("price")
    )
    joined_df = joined_df.withColumn(
        "category_tier",
        when(col("price") < 10, "Budget")
        .when(col("price") < 50, "Standard")
        .otherwise("Premium")
    )

    # Aggregate
    summary = joined_df.groupBy("category", "category_tier") \
        .agg(
            count("*").alias("total_sales"),
            sum("revenue").alias("total_revenue"),
            avg("revenue").alias("avg_revenue")
        ) \
        .orderBy("total_revenue", ascending=False)

    # Load
    summary.write \
        .mode("overwrite") \
        .parquet("output/sales_summary.parquet")

    # Display results
    summary.show(20, truncate=False)

    print("ETL Pipeline completed successfully!")

except Exception as e:
    print(f"Error in ETL pipeline: {str(e)}")
    raise

finally:
    spark.stop()
15. Next Steps
- Spark Streaming: Real-time data processing
- Structured Streaming: Stream processing with DataFrame API
- GraphX: Graph processing and analytics
- Advanced MLlib: Deep learning integration, hyperparameter tuning
- Cluster Deployment: YARN, Kubernetes, or cloud platforms (AWS EMR, Azure HDInsight, Databricks)
- Delta Lake: ACID transactions on data lakes
- Performance Tuning: Advanced optimization techniques