Best Practices
Learn production-ready patterns and best practices for using Databricks effectively, securely, and cost-efficiently.
Workspace Organization
Folder Structure
/Workspace
├── /Users
│   └── /your.email@company.com
│       └── /personal-experiments
├── /Shared
│   ├── /data-engineering
│   │   ├── /bronze-pipelines
│   │   ├── /silver-pipelines
│   │   └── /gold-pipelines
│   ├── /data-science
│   │   ├── /experiments
│   │   ├── /models
│   │   └── /feature-engineering
│   └── /analytics
│       ├── /dashboards
│       └── /reports
└── /Repos
    ├── /team-repo
    └── /shared-libraries
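If the shared folder skeleton needs to be reproducible across workspaces, it can also be created programmatically. The sketch below assumes the databricks-sdk Python package and an already-authenticated environment; the paths mirror the structure above.

from databricks.sdk import WorkspaceClient

# Picks up host and token from the environment or ~/.databrickscfg
w = WorkspaceClient()

# Illustrative subset of the shared folder structure shown above
shared_folders = [
    "/Shared/data-engineering/bronze-pipelines",
    "/Shared/data-engineering/silver-pipelines",
    "/Shared/data-engineering/gold-pipelines",
    "/Shared/data-science/experiments",
    "/Shared/analytics/dashboards",
]

for path in shared_folders:
    w.workspace.mkdirs(path)  # creates intermediate folders as needed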
Naming Conventions
# Notebooks
# Format: [domain]_[action]_[entity]_[version].py
# Examples:
# - etl_ingest_customer_data_v1.py
# - ml_train_churn_model_v2.py
# - analytics_daily_revenue_report.py
# Tables
# Format: [layer].[domain]_[entity]
# Examples:
# - bronze.raw_customer_events
# - silver.clean_customer_events
# - gold.customer_daily_metrics
# Clusters
# Format: [team]-[purpose]-[size]
# Examples:
# - data-eng-etl-large
# - data-sci-training-gpu
# - analytics-adhoc-small
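These conventions are easiest to keep when they are checked automatically. The snippet below is a minimal sketch: the regular expressions simply encode the formats above and can be adjusted to your team's variants.

import re

# Patterns encoding the naming formats above (adjust as conventions evolve)
NOTEBOOK_PATTERN = re.compile(r"^[a-z]+_[a-z]+(_[a-z0-9]+)+(_v\d+)?\.py$")
TABLE_PATTERN = re.compile(r"^(bronze|silver|gold)\.[a-z][a-z0-9_]*$")
CLUSTER_PATTERN = re.compile(r"^[a-z0-9-]+-[a-z0-9-]+-[a-z0-9]+$")

def check_name(name, pattern, kind):
    """Warn when a name does not follow the team convention."""
    if pattern.match(name):
        return True
    print(f"Warning: {kind} name '{name}' does not follow the naming convention")
    return False

check_name("etl_ingest_customer_data_v1.py", NOTEBOOK_PATTERN, "notebook")
check_name("gold.customer_daily_metrics", TABLE_PATTERN, "table")
check_name("data-eng-etl-large", CLUSTER_PATTERN, "cluster")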
Cluster Management
Cluster Configuration
# Development clusters
{
  "cluster_name": "dev-exploratory",
  "spark_version": "13.3.x-scala2.12",
  "node_type_id": "m5.large",
  "num_workers": 0,  # Single node for development
  "autotermination_minutes": 30,
  "spark_conf": {
    # A single-node cluster also needs the singleNode profile and a local master
    "spark.databricks.cluster.profile": "singleNode",
    "spark.master": "local[*]",
    "spark.databricks.delta.preview.enabled": "true"
  },
  "custom_tags": {
    "ResourceClass": "SingleNode"
  }
}

# Production ETL clusters
{
  "cluster_name": "prod-etl",
  "spark_version": "13.3.x-scala2.12",
  "node_type_id": "m5.2xlarge",
  "autoscale": {
    "min_workers": 2,
    "max_workers": 10
  },
  "autotermination_minutes": 60,
  "spark_conf": {
    "spark.sql.adaptive.enabled": "true",
    "spark.sql.adaptive.coalescePartitions.enabled": "true"
  }
}

# ML training clusters
{
  "cluster_name": "ml-training",
  "spark_version": "13.3.x-gpu-ml-scala2.12",
  "node_type_id": "g4dn.xlarge",
  "num_workers": 4,
  "autotermination_minutes": 120,
  "spark_conf": {
    "spark.task.resource.gpu.amount": "1"
  }
}
Cluster Policies
{
  "cluster_type": {
    "type": "fixed",
    "value": "all-purpose"
  },
  "autotermination_minutes": {
    "type": "range",
    "minValue": 15,
    "maxValue": 120,
    "defaultValue": 30
  },
  "spark_version": {
    "type": "regex",
    "pattern": "13\\..*-scala.*",
    "defaultValue": "13.3.x-scala2.12"
  },
  "node_type_id": {
    "type": "allowlist",
    "values": ["m5.large", "m5.xlarge", "m5.2xlarge"],
    "defaultValue": "m5.large"
  }
}
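To put a policy like this into effect, it has to be registered with the workspace. The following is a minimal sketch against the Cluster Policies REST API; the host, token, and policy name are placeholders, and the policy definition is passed as a JSON string.

import json
import requests

DATABRICKS_HOST = "https://<your-workspace>.cloud.databricks.com"  # placeholder
TOKEN = "<personal-access-token>"                                   # placeholder

policy_definition = {
    "autotermination_minutes": {"type": "range", "minValue": 15, "maxValue": 120, "defaultValue": 30},
    "node_type_id": {
        "type": "allowlist",
        "values": ["m5.large", "m5.xlarge", "m5.2xlarge"],
        "defaultValue": "m5.large",
    },
}

resp = requests.post(
    f"{DATABRICKS_HOST}/api/2.0/policies/clusters/create",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={
        "name": "standard-interactive-policy",        # hypothetical policy name
        "definition": json.dumps(policy_definition),  # the API expects a JSON string
    },
)
resp.raise_for_status()
print(resp.json())  # contains the new policy_id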
Cost Optimization
# Use job clusters instead of all-purpose clusters for scheduled jobs;
# job clusters terminate automatically when the job finishes
# (a job-cluster sketch follows the spot-instance example below)

# Enable autoscaling (see the "autoscale" block in the cluster configurations above)
# so shared clusters grow and shrink with the workload

# Adaptive query execution reduces wasted compute at the session level;
# the singleNode profile is a cluster-level setting, shown in the dev cluster config above
spark.conf.set("spark.sql.adaptive.enabled", "true")  # Auto-optimize queries
# Use spot/preemptible instances for non-critical workloads
{
  "aws_attributes": {
    "availability": "SPOT",
    "zone_id": "us-west-2a",
    "spot_bid_price_percent": 100
  }
}
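The job-cluster recommendation above translates into a job definition that declares its own ephemeral compute. This is a hedged sketch against the Jobs API; the workspace URL, token, job name, and notebook path are placeholders.

import requests

DATABRICKS_HOST = "https://<your-workspace>.cloud.databricks.com"  # placeholder
TOKEN = "<personal-access-token>"                                   # placeholder

job_spec = {
    "name": "nightly-etl",  # illustrative job name
    "tasks": [
        {
            "task_key": "run_etl",
            "notebook_task": {
                "notebook_path": "/Shared/data-engineering/silver-pipelines/etl"  # placeholder
            },
            # Ephemeral job cluster: created for the run, terminated when it finishes
            "new_cluster": {
                "spark_version": "13.3.x-scala2.12",
                "node_type_id": "m5.2xlarge",
                "autoscale": {"min_workers": 2, "max_workers": 10},
                "aws_attributes": {"availability": "SPOT_WITH_FALLBACK"},
            },
        }
    ],
}

resp = requests.post(
    f"{DATABRICKS_HOST}/api/2.1/jobs/create",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json=job_spec,
)
resp.raise_for_status()
print(resp.json())  # contains the new job_id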
# Monitor cluster usage
def get_cluster_costs():
    """Summarize DBU consumption per cluster over the last 30 days."""
    # system.billing.usage exposes the cluster id inside the usage_metadata struct;
    # cluster names can be joined in from system.compute.clusters if needed
    usage_df = spark.sql("""
        SELECT
            usage_metadata.cluster_id AS cluster_id,
            usage_date,
            SUM(usage_quantity) AS dbu_hours
        FROM system.billing.usage
        WHERE usage_date >= current_date() - INTERVAL 30 DAYS
          AND usage_metadata.cluster_id IS NOT NULL
        GROUP BY usage_metadata.cluster_id, usage_date
        ORDER BY dbu_hours DESC
    """)
    display(usage_df)

get_cluster_costs()
Data Management
Delta Lake Best Practices
# 1. Use liquid clustering for frequently filtered columns
spark.sql("""
    CREATE TABLE IF NOT EXISTS gold.customer_metrics
    USING DELTA
    CLUSTER BY (date, region)
    AS SELECT * FROM silver.clean_customers
""")

# 2. Enable auto-optimize
spark.sql("""
    ALTER TABLE gold.customer_metrics
    SET TBLPROPERTIES (
        'delta.autoOptimize.optimizeWrite' = 'true',
        'delta.autoOptimize.autoCompact' = 'true'
    )
""")

# 3. Set appropriate retention period
spark.sql("""
    ALTER TABLE gold.customer_metrics
    SET TBLPROPERTIES (
        'delta.deletedFileRetentionDuration' = 'interval 30 days'
    )
""")
# 4. Regular maintenance
from delta.tables import DeltaTable

def maintain_delta_table(table_path, retention_hours=168):
    """Optimize and vacuum a Delta table."""
    delta_table = DeltaTable.forPath(spark, table_path)

    # Optimize (compact small files)
    delta_table.optimize().executeCompaction()

    # Z-order (if applicable)
    # delta_table.optimize().executeZOrderBy("date", "category")

    # Vacuum files older than the retention period (hours)
    delta_table.vacuum(retention_hours)

    print(f"Maintenance completed for {table_path}")

# Schedule regular maintenance
maintain_delta_table("/delta/gold/customer_metrics")
Schema Evolution
# Enable schema evolution
df.write.format("delta") \
    .mode("append") \
    .option("mergeSchema", "true") \
    .saveAsTable("my_table")

# Schema validation before write
def validate_and_write(df, table_name, expected_schema):
    """Validate schema before writing."""
    # Check if all expected columns exist
    missing_cols = set(expected_schema.fieldNames()) - set(df.columns)
    if missing_cols:
        raise ValueError(f"Missing columns: {missing_cols}")

    # Check data types
    for field in expected_schema.fields:
        if field.name in df.columns:
            actual_type = df.schema[field.name].dataType
            if actual_type != field.dataType:
                print(f"Warning: Type mismatch for {field.name}")

    # Write data
    df.write.format("delta").mode("append").saveAsTable(table_name)

# Usage
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

expected_schema = StructType([
    StructField("id", StringType(), False),
    StructField("name", StringType(), False),
    StructField("age", IntegerType(), True)
])

validate_and_write(df, "my_table", expected_schema)
Data Quality Framework
from pyspark.sql.functions import col

class DataQualityValidator:
    """Framework for data quality validation."""

    def __init__(self, df, table_name):
        self.df = df
        self.table_name = table_name
        self.results = []

    def check_not_null(self, column):
        """Check for null values."""
        null_count = self.df.filter(col(column).isNull()).count()
        self.results.append({
            "check": f"{column}_not_null",
            "passed": null_count == 0,
            "null_count": null_count
        })
        return self

    def check_unique(self, column):
        """Check for uniqueness."""
        total_count = self.df.count()
        distinct_count = self.df.select(column).distinct().count()
        self.results.append({
            "check": f"{column}_unique",
            "passed": total_count == distinct_count,
            "duplicates": total_count - distinct_count
        })
        return self

    def check_range(self, column, min_val, max_val):
        """Check value range."""
        out_of_range = self.df.filter(
            (col(column) < min_val) | (col(column) > max_val)
        ).count()
        self.results.append({
            "check": f"{column}_range_{min_val}_to_{max_val}",
            "passed": out_of_range == 0,
            "out_of_range": out_of_range
        })
        return self

    def check_pattern(self, column, pattern):
        """Check regex pattern."""
        non_matching = self.df.filter(
            ~col(column).rlike(pattern)
        ).count()
        self.results.append({
            "check": f"{column}_pattern",
            "passed": non_matching == 0,
            "non_matching": non_matching
        })
        return self

    def validate(self):
        """Run all checks and return results."""
        results_df = spark.createDataFrame(self.results)

        # Log results
        print(f"\nData Quality Report for {self.table_name}")
        display(results_df)

        # Check if any failed
        failed_checks = results_df.filter(col("passed") == False)
        if failed_checks.count() > 0:
            print("⚠️ Some quality checks failed!")
            display(failed_checks)
            return False

        print("✅ All quality checks passed!")
        return True

# Usage
validator = DataQualityValidator(df, "customer_table")
validator \
    .check_not_null("customer_id") \
    .check_unique("customer_id") \
    .check_range("age", 0, 120) \
    .check_pattern("email", r"^[\w\.-]+@[\w\.-]+\.\w+$") \
    .validate()
Security Best Practices
Secret Management
# Never hardcode credentials!
# ❌ Bad
aws_access_key = "AKIAIOSFODNN7EXAMPLE"
aws_secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
# ✅ Good - Use Databricks Secrets
aws_access_key = dbutils.secrets.get(scope="aws-credentials", key="access-key")
aws_secret_key = dbutils.secrets.get(scope="aws-credentials", key="secret-key")
# Set up secrets via CLI
# databricks secrets create-scope --scope aws-credentials
# databricks secrets put --scope aws-credentials --key access-key
# databricks secrets put --scope aws-credentials --key secret-key
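Once credentials live in a secret scope, pass them to connectors at runtime instead of writing them into notebooks. Below is a minimal sketch for a JDBC read; the scope and key names, host, and table are hypothetical.

# Hypothetical scope/key holding a database password
jdbc_password = dbutils.secrets.get(scope="warehouse-credentials", key="etl-user-password")

df = (
    spark.read.format("jdbc")
    .option("url", "jdbc:postgresql://warehouse.example.com:5432/sales")  # placeholder host
    .option("dbtable", "public.orders")                                   # placeholder table
    .option("user", "etl_user")                                           # placeholder user
    .option("password", jdbc_password)  # secret value never appears in the notebook
    .load()
)
display(df)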
Access Control
# Table-level access control with Unity Catalog
spark.sql("""
    GRANT SELECT ON TABLE gold.customer_metrics
    TO `analytics-team`
""")

spark.sql("""
    GRANT MODIFY ON TABLE silver.clean_data
    TO `data-engineering-team`
""")

# Row-level security
spark.sql("""
    CREATE TABLE gold.sensitive_customer_data (
        customer_id STRING,
        name STRING,
        region STRING,
        revenue DECIMAL(10,2)
    )
    TBLPROPERTIES (
        'delta.enableRowTracking' = 'true'
    )
""")

# Dynamic views for row-level filtering
# Note: current_user_region() is an illustrative UDF, not a built-in function;
# see the mapping-table pattern below for one way to implement it
spark.sql("""
    CREATE VIEW regional_customer_data AS
    SELECT * FROM gold.sensitive_customer_data
    WHERE region = current_user_region()
""")
Data Masking
from pyspark.sql.functions import sha2, regexp_replace

# Mask sensitive data
def mask_sensitive_data(df):
    """Apply data masking to sensitive columns."""
    return df \
        .withColumn("email_masked", regexp_replace("email", "(?<=.{3}).(?=.*@)", "*")) \
        .withColumn("ssn_hashed", sha2("ssn", 256)) \
        .drop("email", "ssn") \
        .withColumnRenamed("email_masked", "email") \
        .withColumnRenamed("ssn_hashed", "ssn_hash")

# Usage
masked_df = mask_sensitive_data(df)
display(masked_df)
Performance Optimization
Query Optimization
# 1. Filter early
# ❌ Bad - aggregates the entire table before narrowing the result
result = df.select("*").groupBy("category").count().filter("count > 100")
# ✅ Good - apply selective row filters before wide operations like groupBy
result = df.filter("date >= '2024-01-01'").groupBy("category").count()

# 2. Broadcast small tables
from pyspark.sql.functions import broadcast

# ✅ Good for joins with small dimension tables
large_fact.join(broadcast(small_dim), "key")

# 3. Avoid unnecessary shuffles
# ❌ Bad - repartition forces a full shuffle
df.repartition(100)
# ✅ Good - coalesce reduces partitions without a shuffle
df.coalesce(10)

# 4. Cache strategically
# Cache DataFrames used multiple times
df_cached = df.filter(col("important_column").isNotNull()).cache()
result1 = df_cached.groupBy("col1").count()
result2 = df_cached.groupBy("col2").avg("value")
df_cached.unpersist()

# 5. Use Delta Lake optimizations
# (ZORDER applies to tables that are not using liquid clustering)
spark.sql("""
    OPTIMIZE gold.customer_metrics
    ZORDER BY (date, customer_id)
""")
Monitoring and Debugging
# Enable query debugging
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.logLevel", "INFO")

# Analyze query plans
df.explain(mode="formatted")
df.explain(mode="cost")

# Monitor Spark UI metrics
def analyze_query_performance(df):
    """Analyze and log query performance metrics."""
    import time

    start_time = time.time()
    count = df.count()
    execution_time = time.time() - start_time

    # Get query plan
    print("Physical Plan:")
    df.explain(mode="formatted")

    # Log metrics
    print("\nPerformance Metrics:")
    print(f"Row Count: {count:,}")
    print(f"Execution Time: {execution_time:.2f}s")
    print(f"Rows/Second: {count/execution_time:,.0f}")

    return {
        "count": count,
        "execution_time": execution_time,
        "throughput": count / execution_time
    }

# Usage
metrics = analyze_query_performance(df)
CI/CD and Version Control
Git Integration
# Directory structure for Databricks Repos
"""
/Repos
├── .gitignore
├── README.md
├── /notebooks
│   ├── /etl
│   │   ├── 01_bronze_ingestion.py
│   │   ├── 02_silver_transformation.py
│   │   └── 03_gold_aggregation.py
│   └── /ml
│       ├── train_model.py
│       └── evaluate_model.py
├── /src
│   ├── __init__.py
│   ├── /utils
│   │   ├── data_quality.py
│   │   └── spark_helpers.py
│   └── /models
│       └── feature_engineering.py
├── /tests
│   ├── test_data_quality.py
│   └── test_transformations.py
├── /config
│   ├── dev.yaml
│   ├── staging.yaml
│   └── prod.yaml
└── requirements.txt
"""
Testing Framework
# Unit tests for Spark transformations
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

@pytest.fixture(scope="session")
def spark():
    return SparkSession.builder.appName("test").getOrCreate()

def test_data_transformation(spark):
    """Test data transformation logic."""
    # Create test data
    test_data = [
        (1, "John", 25),
        (2, "Jane", 30),
        (3, "Bob", None)
    ]
    df = spark.createDataFrame(test_data, ["id", "name", "age"])

    # Apply transformation
    result = df.filter(col("age").isNotNull())

    # Assertions
    assert result.count() == 2
    assert result.filter(col("id") == 3).count() == 0

def test_aggregation(spark):
    """Test aggregation logic."""
    test_data = [
        ("A", 100),
        ("A", 200),
        ("B", 150)
    ]
    df = spark.createDataFrame(test_data, ["category", "value"])

    result = df.groupBy("category").sum("value")

    assert result.count() == 2
    result_dict = {row["category"]: row["sum(value)"] for row in result.collect()}
    assert result_dict["A"] == 300
    assert result_dict["B"] == 150

# Integration tests
def test_end_to_end_pipeline(spark):
    """Test complete pipeline."""
    # Bronze layer
    bronze_df = spark.read.csv("/test-data/input.csv", header=True)
    bronze_df.write.format("delta").mode("overwrite").saveAsTable("test_bronze")

    # Silver layer
    silver_df = spark.table("test_bronze").filter(col("id").isNotNull())
    silver_df.write.format("delta").mode("overwrite").saveAsTable("test_silver")

    # Gold layer
    gold_df = spark.table("test_silver").groupBy("category").count()
    gold_df.write.format("delta").mode("overwrite").saveAsTable("test_gold")

    # Verify
    assert spark.table("test_gold").count() > 0
Deployment Pipeline
# .github/workflows/databricks-ci-cd.yml
name: Databricks CI/CD

on:
  push:
    branches: [main, dev]
  pull_request:
    branches: [main]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.9'
      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          pip install pytest
      - name: Run tests
        run: pytest tests/

  deploy:
    needs: test
    if: github.ref == 'refs/heads/main'
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Install Databricks CLI
        run: pip install databricks-cli
      - name: Deploy to Databricks
        env:
          DATABRICKS_HOST: ${{ secrets.DATABRICKS_HOST }}
          DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }}
        run: |
          databricks workspace import_dir \
            ./notebooks \
            /Production/notebooks \
            --overwrite
Logging and Monitoring
Structured Logging
import logging
import json
from datetime import datetime

class StructuredLogger:
    """Custom structured logger for Databricks."""

    def __init__(self, name):
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.INFO)

    def log(self, level, message, **kwargs):
        """Log structured message."""
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "level": level,
            "message": message,
            "metadata": kwargs
        }
        self.logger.log(getattr(logging, level), json.dumps(log_entry))

    def info(self, message, **kwargs):
        self.log("INFO", message, **kwargs)

    def error(self, message, **kwargs):
        self.log("ERROR", message, **kwargs)

    def warning(self, message, **kwargs):
        self.log("WARNING", message, **kwargs)

# Usage
logger = StructuredLogger("etl_pipeline")

logger.info("Starting data processing",
            table="customer_data",
            row_count=10000)

try:
    # Process data
    pass
except Exception as e:
    logger.error("Processing failed",
                 error=str(e),
                 table="customer_data")
Operational Metrics
def track_pipeline_metrics(func):
    """Decorator to track pipeline execution metrics."""
    import time
    import mlflow

    def wrapper(*args, **kwargs):
        start_time = time.time()

        with mlflow.start_run(run_name=func.__name__):
            try:
                result = func(*args, **kwargs)
                execution_time = time.time() - start_time

                # Log metrics
                mlflow.log_metric("execution_time", execution_time)
                mlflow.log_metric("status", 1)  # Success

                if hasattr(result, 'count'):
                    row_count = result.count()
                    mlflow.log_metric("row_count", row_count)

                return result

            except Exception as e:
                execution_time = time.time() - start_time
                mlflow.log_metric("execution_time", execution_time)
                mlflow.log_metric("status", 0)  # Failure
                mlflow.log_param("error", str(e))
                raise

    return wrapper

# Usage
@track_pipeline_metrics
def process_customer_data():
    df = spark.table("bronze.raw_customers")
    transformed = df.filter(col("email").isNotNull())
    transformed.write.format("delta").mode("overwrite").saveAsTable("silver.customers")
    return transformed

result = process_customer_data()
Documentation
Self-Documenting Code
"""
Customer Churn Prediction Pipeline
This notebook processes customer data and trains a churn prediction model.
Inputs:
- bronze.customer_events: Raw customer interaction events
- bronze.customer_profile: Customer demographic data
Outputs:
- silver.customer_features: Engineered features for ML
- models:/customer_churn_model: Trained ML model
Schedule: Daily at 2 AM UTC
Author: Data Science Team
Last Updated: 2024-01-15
"""
from pyspark.sql.functions import col, datediff, current_date
import mlflow
def engineer_features(events_df, profile_df):
"""
Engineer features from raw customer data.
Args:
events_df: Customer events DataFrame
profile_df: Customer profile DataFrame
Returns:
DataFrame with engineered features
Features created:
- days_since_last_purchase: Days since last transaction
- total_purchases: Total number of purchases
- avg_purchase_value: Average transaction value
"""
features = events_df.groupBy("customer_id").agg(
# Add aggregations
)
return features.join(profile_df, "customer_id")
# Execute pipeline
features_df = engineer_features(events_df, profile_df)