Training Models
Deep dive into training machine learning models with PostgresML, including advanced techniques and optimization strategies.
Training Function Overview
The pgml.train() function is the core of PostgresML:
-- Signature of pgml.train() (illustrative parameter listing with defaults;
-- not an executable statement).
SELECT * FROM pgml.train(
project_name TEXT, -- Project identifier
task TEXT DEFAULT NULL, -- 'classification', 'regression', 'clustering'
relation_name TEXT DEFAULT NULL, -- Table or query
y_column_name TEXT DEFAULT NULL, -- Target column
algorithm TEXT DEFAULT 'linear', -- Algorithm to use
hyperparams JSONB DEFAULT '{}', -- Algorithm parameters
search TEXT DEFAULT NULL, -- 'grid' or 'random'
search_params JSONB DEFAULT '{}', -- Search parameters
search_args JSONB DEFAULT '{}', -- Additional search args
test_size FLOAT DEFAULT 0.25, -- Test set size
test_sampling TEXT DEFAULT 'random' -- Sampling strategy
);
Classification Tasks
Binary Classification
-- Sample dataset for binary (spam / not spam) classification.
CREATE TABLE emails (
id SERIAL PRIMARY KEY,
subject TEXT,
body TEXT,
sender TEXT,
has_attachments BOOLEAN,
word_count INT,
is_spam BOOLEAN -- Target variable (TRUE = spam)
);
-- Train binary classifier on the emails table.
-- hyperparams is a JSONB value: it must be valid JSON (no comments inside).
SELECT * FROM pgml.train(
project_name => 'spam_detector',
task => 'classification',
relation_name => 'emails',
y_column_name => 'is_spam',
algorithm => 'xgboost',
hyperparams => '{
"n_estimators": 100,
"max_depth": 7,
"learning_rate": 0.1,
"objective": "binary:logistic"
}'
);
Multi-class Classification
-- Multi-class dataset: category holds one of several text labels.
CREATE TABLE documents (
id SERIAL PRIMARY KEY,
title TEXT,
content TEXT,
author TEXT,
word_count INT,
category TEXT -- Multiple categories: 'tech', 'sports', 'politics', etc.
);
-- Train multi-class classifier; class_weight "balanced" reweights classes
-- inversely to their frequency so rare categories are not ignored.
SELECT * FROM pgml.train(
project_name => 'document_classifier',
task => 'classification',
relation_name => 'documents',
y_column_name => 'category',
algorithm => 'random_forest',
hyperparams => '{
"n_estimators": 200,
"max_depth": 15,
"min_samples_split": 5,
"class_weight": "balanced"
}'
);
Imbalanced Classes
-- Handle imbalanced datasets.
-- hyperparams is parsed as JSON, and JSON does not allow comments — any
-- annotation must live outside the quoted literal, otherwise the string
-- is invalid JSON and training fails.
SELECT * FROM pgml.train(
project_name => 'fraud_detector',
task => 'classification',
relation_name => 'transactions',
y_column_name => 'is_fraud',
algorithm => 'xgboost',
-- scale_pos_weight up-weights the rare positive (fraud) class
hyperparams => '{
"n_estimators": 150,
"scale_pos_weight": 10,
"max_depth": 8,
"learning_rate": 0.05
}'
);
-- Or use class_weight for scikit-learn algorithms
-- ("balanced" adjusts weights inversely proportional to class frequencies).
SELECT * FROM pgml.train(
project_name => 'fraud_detector_v2',
task => 'classification',
relation_name => 'transactions',
y_column_name => 'is_fraud',
algorithm => 'random_forest',
hyperparams => '{
"n_estimators": 200,
"class_weight": "balanced",
"max_depth": 12
}'
);
Regression Tasks
Linear Regression
-- Simple linear regression: predict price from basic house attributes.
CREATE TABLE housing (
id SERIAL PRIMARY KEY,
square_feet FLOAT,
bedrooms INT,
bathrooms FLOAT,
year_built INT,
price FLOAT -- Target (continuous)
);
-- Baseline linear model: fast to train and easy to interpret.
SELECT * FROM pgml.train(
project_name => 'house_prices',
task => 'regression',
relation_name => 'housing',
y_column_name => 'price',
algorithm => 'linear_regression'
);
Advanced Regression
-- Gradient boosting for regression.
-- subsample / colsample_bytree randomize rows and columns per tree to
-- reduce overfitting; min_child_weight limits how small leaves may get.
SELECT * FROM pgml.train(
project_name => 'advanced_house_prices',
task => 'regression',
relation_name => 'housing',
y_column_name => 'price',
algorithm => 'xgboost',
hyperparams => '{
"n_estimators": 200,
"max_depth": 8,
"learning_rate": 0.05,
"min_child_weight": 3,
"subsample": 0.8,
"colsample_bytree": 0.8,
"objective": "reg:squarederror"
}'
);
-- LightGBM for large datasets (num_leaves bounds per-tree complexity;
-- feature_fraction subsamples columns for each tree).
SELECT * FROM pgml.train(
project_name => 'fast_house_prices',
task => 'regression',
relation_name => 'housing',
y_column_name => 'price',
algorithm => 'lightgbm',
hyperparams => '{
"n_estimators": 300,
"num_leaves": 31,
"learning_rate": 0.05,
"feature_fraction": 0.8
}'
);
Time Series Forecasting
-- Prepare time series data: the target is a pre-computed lead column.
-- NOTE(review): "timestamp" is a SQL keyword used here as a column name;
-- consider renaming (e.g. recorded_at) to avoid quoting surprises.
CREATE TABLE sensor_readings (
timestamp TIMESTAMPTZ,
sensor_id INT,
temperature FLOAT,
humidity FLOAT,
pressure FLOAT,
next_hour_temp FLOAT -- Target: temperature in next hour
);
-- Train forecasting model: framed as plain regression against the
-- pre-computed next_hour_temp lead column rather than a dedicated
-- time-series task.
SELECT * FROM pgml.train(
project_name => 'temperature_forecast',
task => 'regression',
relation_name => 'sensor_readings',
y_column_name => 'next_hour_temp',
algorithm => 'xgboost',
hyperparams => '{
"n_estimators": 250,
"max_depth": 6,
"learning_rate": 0.1
}'
);
Clustering Tasks
K-Means Clustering
-- Customer segmentation features; clustering is unsupervised, so there
-- is no target column.
CREATE TABLE customers (
id SERIAL PRIMARY KEY,
age INT,
annual_income FLOAT,
spending_score INT,
purchase_frequency INT
);
-- K-means with 5 clusters; k-means++ initialization and n_init restarts
-- make the chosen centroids more stable. Note: no y_column_name for clustering.
SELECT * FROM pgml.train(
project_name => 'customer_segments',
task => 'clustering',
relation_name => 'customers',
algorithm => 'kmeans',
hyperparams => '{
"n_clusters": 5,
"init": "k-means++",
"n_init": 10,
"max_iter": 300
}'
);
DBSCAN Clustering
-- Density-based clustering for anomaly detection
-- (eps = neighborhood radius, min_samples = points required to form a
-- dense region; points in no cluster are treated as noise/anomalies).
SELECT * FROM pgml.train(
project_name => 'anomaly_clusters',
task => 'clustering',
relation_name => 'network_traffic',
algorithm => 'dbscan',
hyperparams => '{
"eps": 0.5,
"min_samples": 5
}'
);
Algorithm Selection Guide
Tree-Based Algorithms
Choose tree-based algorithms for:
- Non-linear relationships
- Mixed data types (categorical and numerical)
- Feature importance insights
-- Random Forest: good all-rounder
-- (positional args: project_name, task, relation_name, y_column_name)
SELECT * FROM pgml.train(
'my_project', 'classification', 'data', 'label',
algorithm => 'random_forest',
hyperparams => '{"n_estimators": 100, "max_depth": 10}'
);
-- XGBoost: high performance, handles missing data natively
SELECT * FROM pgml.train(
'my_project', 'classification', 'data', 'label',
algorithm => 'xgboost',
hyperparams => '{"n_estimators": 100, "learning_rate": 0.1}'
);
-- LightGBM: very fast on large datasets
SELECT * FROM pgml.train(
'my_project', 'classification', 'data', 'label',
algorithm => 'lightgbm',
hyperparams => '{"n_estimators": 100, "num_leaves": 31}'
);
Linear Models
Choose linear models for:
- Linear relationships
- High-dimensional data
- Interpretability
- Fast training and prediction
-- Logistic Regression (max_iter raised so the solver can converge)
SELECT * FROM pgml.train(
'my_project', 'classification', 'data', 'label',
algorithm => 'logistic_regression',
hyperparams => '{"max_iter": 1000}'
);
-- Ridge Regression (L2 regularization; alpha controls its strength)
SELECT * FROM pgml.train(
'my_project', 'regression', 'data', 'target',
algorithm => 'ridge',
hyperparams => '{"alpha": 1.0}'
);
-- Lasso Regression (L1 regularization; drives weak coefficients to zero)
SELECT * FROM pgml.train(
'my_project', 'regression', 'data', 'target',
algorithm => 'lasso',
hyperparams => '{"alpha": 0.1}'
);
Neural Networks
-- Multi-layer Perceptron: two hidden layers (100 then 50 units), ReLU
-- activations, adaptive learning-rate schedule.
SELECT * FROM pgml.train(
'my_project', 'classification', 'data', 'label',
algorithm => 'mlp_classifier',
hyperparams => '{
"hidden_layer_sizes": [100, 50],
"activation": "relu",
"max_iter": 500,
"learning_rate": "adaptive"
}'
);
Hyperparameter Optimization
Grid Search
Exhaustively search through hyperparameter combinations:
-- Grid search: fits every combination in search_params
-- (3 x 3 x 3 = 27 candidates, each evaluated with 5-fold CV).
SELECT * FROM pgml.train(
project_name => 'optimized_model',
task => 'classification',
relation_name => 'data',
y_column_name => 'label',
algorithm => 'xgboost',
search => 'grid',
search_params => '{
"n_estimators": [50, 100, 200],
"max_depth": [3, 5, 7],
"learning_rate": [0.01, 0.1, 0.3]
}',
search_args => '{
"cv": 5,
"scoring": "accuracy"
}'
);
Random Search
Sample random combinations (faster for large spaces):
-- Random search: samples n_iter combinations from search_params instead
-- of exhaustively trying them all — faster for large spaces.
SELECT * FROM pgml.train(
project_name => 'random_optimized',
task => 'regression',
relation_name => 'data',
y_column_name => 'target',
algorithm => 'random_forest',
search => 'random',
search_params => '{
"n_estimators": [50, 100, 150, 200],
"max_depth": [5, 10, 15, 20, 25],
"min_samples_split": [2, 5, 10]
}',
search_args => '{
"n_iter": 20,
"cv": 3,
"scoring": "neg_mean_squared_error"
}'
);
Cross-Validation
K-Fold Cross-Validation
-- Cross-validation is built into hyperparameter search via the "cv" key.
-- search_args is parsed as JSON, which does not allow comments — any
-- annotation must live outside the quoted literal, otherwise the string
-- is invalid JSON and training fails.
SELECT * FROM pgml.train(
project_name => 'cv_model',
task => 'classification',
relation_name => 'data',
y_column_name => 'label',
algorithm => 'random_forest',
search => 'grid',
search_params => '{
"n_estimators": [100, 200]
}',
-- cv: 5-fold cross-validation; scoring: F1 weighted by class support
search_args => '{
"cv": 5,
"scoring": "f1_weighted"
}'
);
Training with Complex Queries
Feature Engineering in SQL
-- Train with engineered features: relation_name accepts a full query
-- (dollar-quoted here), not just a table name.
-- NOTE(review): the correlated avg_purchase subquery runs once per
-- customer row; pre-aggregating purchases in a join may be preferable
-- on large tables.
SELECT * FROM pgml.train(
project_name => 'engineered_model',
task => 'classification',
relation_name => $$
SELECT
-- Original features
age,
income,
-- Engineered features
EXTRACT(YEAR FROM age(date_of_birth)) AS calculated_age,
income / NULLIF(household_size, 0) AS income_per_person,
CASE
WHEN age < 30 THEN 'young'
WHEN age < 50 THEN 'middle'
ELSE 'senior'
END AS age_group,
-- Aggregations (correlated scalar subquery)
(SELECT AVG(purchase_amount)
FROM purchases p
WHERE p.customer_id = c.id) AS avg_purchase,
-- Target
is_premium_customer
FROM customers c
$$,
y_column_name => 'is_premium_customer',
algorithm => 'xgboost'
);
Joining Multiple Tables
-- Train on a multi-table join: per-user aggregates become features.
-- LEFT JOINs keep users with no orders/reviews (their aggregates are NULL).
SELECT * FROM pgml.train(
project_name => 'multi_table_model',
task => 'regression',
relation_name => $$
SELECT
u.age,
u.location,
COUNT(DISTINCT o.id) AS order_count,
SUM(o.total) AS lifetime_value,
AVG(r.rating) AS avg_rating,
MAX(o.created_at) AS last_order_date,
-- Target
u.predicted_churn_score
FROM users u
LEFT JOIN orders o ON u.id = o.user_id
LEFT JOIN reviews r ON u.id = r.user_id
GROUP BY u.id, u.age, u.location, u.predicted_churn_score
$$,
y_column_name => 'predicted_churn_score'
);
Model Versioning
PostgresML automatically versions models:
-- Train multiple versions: each call records a new model under the project.
SELECT * FROM pgml.train('my_project', 'classification', 'data', 'label',
algorithm => 'random_forest');
SELECT * FROM pgml.train('my_project', 'classification', 'data', 'label',
algorithm => 'xgboost');
-- View all versions.
-- NOTE(review): assumes pgml.models exposes project_name and a JSONB
-- metrics column — verify against your PostgresML version's schema.
SELECT
id,
created_at,
algorithm,
metrics->>'accuracy' AS accuracy
FROM pgml.models
WHERE project_name = 'my_project'
ORDER BY created_at DESC;
-- Deploy a specific version by model id
SELECT pgml.deploy(
project_name => 'my_project',
model_id => 42
);
Monitoring Training
Track Training Progress
-- View recent training jobs.
-- NOTE(review): assumes pgml.models has status and training_time columns —
-- confirm against your PostgresML version's schema.
SELECT
project_name,
algorithm,
status,
created_at,
training_time
FROM pgml.models
ORDER BY created_at DESC
LIMIT 10;
Compare Model Performance
-- Compare different algorithms by aggregating the stored accuracy metric
-- (extracted from JSONB as text, then cast to float).
SELECT
algorithm,
AVG((metrics->>'accuracy')::float) AS avg_accuracy,
MIN((metrics->>'accuracy')::float) AS min_accuracy,
MAX((metrics->>'accuracy')::float) AS max_accuracy,
COUNT(*) AS num_models
FROM pgml.models
WHERE project_name = 'my_project'
GROUP BY algorithm
ORDER BY avg_accuracy DESC;
Best Practices
1. Start with a Baseline
-- Always train a simple baseline first so later, more complex models
-- have a reference point to beat.
SELECT * FROM pgml.train(
'baseline_model',
'classification',
'data',
'label',
algorithm => 'logistic_regression'
);
2. Use Appropriate Metrics
-- For imbalanced classification, score with F1 (or AUC) rather than
-- accuracy. A grid search needs search_params to enumerate candidates;
-- without it there is nothing to search over, so a small grid is supplied.
SELECT * FROM pgml.train(
'balanced_model',
'classification',
'data',
'label',
algorithm => 'xgboost',
search => 'grid',
search_params => '{
"n_estimators": [100, 200],
"max_depth": [4, 8]
}',
search_args => '{
"cv": 5,
"scoring": "f1_weighted"
}'
);
3. Validate with Test Data
-- Use an appropriate held-out test size; stratified sampling preserves the
-- class distribution in the train/test split.
SELECT * FROM pgml.train(
'validated_model',
'classification',
'data',
'label',
test_size => 0.2, -- 20% for testing
test_sampling => 'stratified' -- Maintain class distribution
);
4. Regular Retraining
-- Retrain with fresh data periodically: relation_name can be a query
-- string that selects only recent rows (note the doubled single quotes
-- to escape the INTERVAL literal inside the outer string).
SELECT * FROM pgml.train(
'production_model',
'classification',
'SELECT * FROM data WHERE created_at > NOW() - INTERVAL ''30 days''',
'label',
algorithm => 'xgboost'
);
Next Steps
Continue learning about:
- Making Predictions - Use trained models
- Transformers & NLP - Natural language processing
- Advanced Examples - Complex use cases