Skip to main content

Training Models

Deep dive into training machine learning models with PostgresML, including advanced techniques and optimization strategies.

Training Function Overview

The pgml.train() function is the core of PostgresML:

-- Parameter reference (pseudo-code, not runnable SQL): shows pgml.train()'s
-- arguments with their types and defaults. Call it with named arguments,
-- e.g. pgml.train(project_name => 'p', task => 'classification', ...).
SELECT * FROM pgml.train(
project_name TEXT, -- Project identifier
task TEXT DEFAULT NULL, -- 'classification', 'regression', 'clustering'
relation_name TEXT DEFAULT NULL, -- Table or query
y_column_name TEXT DEFAULT NULL, -- Target column
algorithm TEXT DEFAULT 'linear', -- Algorithm to use
hyperparams JSONB DEFAULT '{}', -- Algorithm parameters
search TEXT DEFAULT NULL, -- 'grid' or 'random'
search_params JSONB DEFAULT '{}', -- Search parameters
search_args JSONB DEFAULT '{}', -- Additional search args
test_size FLOAT DEFAULT 0.25, -- Test set size
test_sampling TEXT DEFAULT 'random' -- Sampling strategy
);

Classification Tasks

Binary Classification

-- Sample dataset: one row per email, is_spam is the binary target.
CREATE TABLE emails (
    id              SERIAL PRIMARY KEY,
    subject         TEXT,
    body            TEXT,
    sender          TEXT,
    has_attachments BOOLEAN,
    word_count      INT,
    is_spam         BOOLEAN  -- target variable (TRUE = spam)
);

-- Train a binary classifier on the emails table.
SELECT * FROM pgml.train(
    project_name  => 'spam_detector',
    task          => 'classification',
    relation_name => 'emails',
    y_column_name => 'is_spam',
    algorithm     => 'xgboost',
    hyperparams   => '{
        "n_estimators": 100,
        "max_depth": 7,
        "learning_rate": 0.1,
        "objective": "binary:logistic"
    }'
);

Multi-class Classification

-- Multi-class dataset: category is a text label with several values.
CREATE TABLE documents (
    id         SERIAL PRIMARY KEY,
    title      TEXT,
    content    TEXT,
    author     TEXT,
    word_count INT,
    category   TEXT  -- target: 'tech', 'sports', 'politics', etc.
);

-- Train a multi-class classifier; "class_weight": "balanced" compensates
-- for unevenly sized categories.
SELECT * FROM pgml.train(
    project_name  => 'document_classifier',
    task          => 'classification',
    relation_name => 'documents',
    y_column_name => 'category',
    algorithm     => 'random_forest',
    hyperparams   => '{
        "n_estimators": 200,
        "max_depth": 15,
        "min_samples_split": 5,
        "class_weight": "balanced"
    }'
);

Imbalanced Classes

-- Handle imbalanced datasets.
-- NOTE: JSON has no comment syntax — keep annotations outside the
-- hyperparams string literal, otherwise JSONB parsing fails.
-- scale_pos_weight up-weights the rare positive (fraud) class.
SELECT * FROM pgml.train(
    project_name  => 'fraud_detector',
    task          => 'classification',
    relation_name => 'transactions',
    y_column_name => 'is_fraud',
    algorithm     => 'xgboost',
    hyperparams   => '{
        "n_estimators": 150,
        "scale_pos_weight": 10,
        "max_depth": 8,
        "learning_rate": 0.05
    }'
);

-- Alternative: scikit-learn algorithms balance rare classes via class_weight.
SELECT * FROM pgml.train(
    project_name  => 'fraud_detector_v2',
    task          => 'classification',
    relation_name => 'transactions',
    y_column_name => 'is_fraud',
    algorithm     => 'random_forest',
    hyperparams   => '{
        "n_estimators": 200,
        "class_weight": "balanced",
        "max_depth": 12
    }'
);

Regression Tasks

Linear Regression

-- Housing dataset; price is the continuous regression target.
CREATE TABLE housing (
    id          SERIAL PRIMARY KEY,
    square_feet FLOAT,
    bedrooms    INT,
    bathrooms   FLOAT,
    year_built  INT,
    price       FLOAT  -- target
);

-- Ordinary linear regression: no hyperparameters required.
SELECT * FROM pgml.train(
    project_name  => 'house_prices',
    task          => 'regression',
    relation_name => 'housing',
    y_column_name => 'price',
    algorithm     => 'linear_regression'
);

Advanced Regression

-- Gradient-boosted trees for regression; subsample and colsample_bytree
-- introduce randomness per tree.
SELECT * FROM pgml.train(
    project_name  => 'advanced_house_prices',
    task          => 'regression',
    relation_name => 'housing',
    y_column_name => 'price',
    algorithm     => 'xgboost',
    hyperparams   => '{
        "n_estimators": 200,
        "max_depth": 8,
        "learning_rate": 0.05,
        "min_child_weight": 3,
        "subsample": 0.8,
        "colsample_bytree": 0.8,
        "objective": "reg:squarederror"
    }'
);

-- LightGBM: fast training on large datasets.
SELECT * FROM pgml.train(
    project_name  => 'fast_house_prices',
    task          => 'regression',
    relation_name => 'housing',
    y_column_name => 'price',
    algorithm     => 'lightgbm',
    hyperparams   => '{
        "n_estimators": 300,
        "num_leaves": 31,
        "learning_rate": 0.05,
        "feature_fraction": 0.8
    }'
);

Time Series Forecasting

-- Prepare time series data
-- One row per reading; next_hour_temp is a pre-computed lead of
-- temperature (its value one hour later), used as the regression target.
-- NOTE(review): column "timestamp" shadows the SQL type/keyword name —
-- a name like reading_at would be safer; table also lacks a primary key.
CREATE TABLE sensor_readings (
timestamp TIMESTAMPTZ,
sensor_id INT,
temperature FLOAT,
humidity FLOAT,
pressure FLOAT,
next_hour_temp FLOAT -- Target: temperature in next hour
);

-- Forecasting framed as regression: predict the pre-computed
-- next_hour_temp column from the current readings.
SELECT * FROM pgml.train(
    project_name  => 'temperature_forecast',
    task          => 'regression',
    relation_name => 'sensor_readings',
    y_column_name => 'next_hour_temp',
    algorithm     => 'xgboost',
    hyperparams   => '{
        "n_estimators": 250,
        "max_depth": 6,
        "learning_rate": 0.1
    }'
);

Clustering Tasks

K-Means Clustering

-- Customer attributes used for unsupervised segmentation.
CREATE TABLE customers (
    id                 SERIAL PRIMARY KEY,
    age                INT,
    annual_income      FLOAT,
    spending_score     INT,
    purchase_frequency INT
);

-- Clustering is unsupervised, so no y_column_name is passed.
SELECT * FROM pgml.train(
    project_name  => 'customer_segments',
    task          => 'clustering',
    relation_name => 'customers',
    algorithm     => 'kmeans',
    hyperparams   => '{
        "n_clusters": 5,
        "init": "k-means++",
        "n_init": 10,
        "max_iter": 300
    }'
);

DBSCAN Clustering

-- DBSCAN groups points in dense regions; points in sparse regions fall
-- outside every cluster, which is what makes it useful for anomaly detection.
SELECT * FROM pgml.train(
    project_name  => 'anomaly_clusters',
    task          => 'clustering',
    relation_name => 'network_traffic',
    algorithm     => 'dbscan',
    hyperparams   => '{
        "eps": 0.5,
        "min_samples": 5
    }'
);

Algorithm Selection Guide

Tree-Based Algorithms

Choose tree-based algorithms for:

  • Non-linear relationships
  • Mixed data types (categorical and numerical)
  • Feature importance insights
-- Random Forest: a solid general-purpose starting point.
SELECT * FROM pgml.train(
    'my_project', 'classification', 'data', 'label',
    algorithm   => 'random_forest',
    hyperparams => '{"n_estimators": 100, "max_depth": 10}'
);

-- XGBoost: high performance; handles missing values natively.
SELECT * FROM pgml.train(
    'my_project', 'classification', 'data', 'label',
    algorithm   => 'xgboost',
    hyperparams => '{"n_estimators": 100, "learning_rate": 0.1}'
);

-- LightGBM: very fast on large datasets.
SELECT * FROM pgml.train(
    'my_project', 'classification', 'data', 'label',
    algorithm   => 'lightgbm',
    hyperparams => '{"n_estimators": 100, "num_leaves": 31}'
);

Linear Models

Choose linear models for:

  • Linear relationships
  • High-dimensional data
  • Interpretability
  • Fast training and prediction
-- Logistic regression: fast, interpretable linear classifier.
SELECT * FROM pgml.train(
    'my_project', 'classification', 'data', 'label',
    algorithm   => 'logistic_regression',
    hyperparams => '{"max_iter": 1000}'
);

-- Ridge regression: linear model with L2 regularization (strength = alpha).
SELECT * FROM pgml.train(
    'my_project', 'regression', 'data', 'target',
    algorithm   => 'ridge',
    hyperparams => '{"alpha": 1.0}'
);

-- Lasso regression: L1 regularization, drives some coefficients to zero.
SELECT * FROM pgml.train(
    'my_project', 'regression', 'data', 'target',
    algorithm   => 'lasso',
    hyperparams => '{"alpha": 0.1}'
);

Neural Networks

-- Multi-layer perceptron: two hidden layers (100 then 50 units).
SELECT * FROM pgml.train(
    'my_project', 'classification', 'data', 'label',
    algorithm   => 'mlp_classifier',
    hyperparams => '{
        "hidden_layer_sizes": [100, 50],
        "activation": "relu",
        "max_iter": 500,
        "learning_rate": "adaptive"
    }'
);

Hyperparameter Optimization

Grid Search — exhaustively evaluate every hyperparameter combination:

-- Grid search: every listed combination (3 x 3 x 3 = 27 candidates here)
-- is trained and scored with 5-fold cross-validation.
SELECT * FROM pgml.train(
    project_name  => 'optimized_model',
    task          => 'classification',
    relation_name => 'data',
    y_column_name => 'label',
    algorithm     => 'xgboost',
    search        => 'grid',
    search_params => '{
        "n_estimators": [50, 100, 200],
        "max_depth": [3, 5, 7],
        "learning_rate": [0.01, 0.1, 0.3]
    }',
    search_args   => '{
        "cv": 5,
        "scoring": "accuracy"
    }'
);

Random Search — sample random combinations (faster for large search spaces):

-- Random search: n_iter candidates sampled from the space, scored with
-- 3-fold CV on negated MSE (larger is better).
SELECT * FROM pgml.train(
    project_name  => 'random_optimized',
    task          => 'regression',
    relation_name => 'data',
    y_column_name => 'target',
    algorithm     => 'random_forest',
    search        => 'random',
    search_params => '{
        "n_estimators": [50, 100, 150, 200],
        "max_depth": [5, 10, 15, 20, 25],
        "min_samples_split": [2, 5, 10]
    }',
    search_args   => '{
        "n_iter": 20,
        "cv": 3,
        "scoring": "neg_mean_squared_error"
    }'
);

Cross-Validation

K-Fold Cross-Validation

-- Cross-validation is built into hyperparameter search via search_args.
-- "cv": 5 requests 5-fold cross-validation.
-- NOTE: JSON has no comment syntax — keep annotations outside the
-- search_args string literal, otherwise JSONB parsing fails.
SELECT * FROM pgml.train(
    project_name  => 'cv_model',
    task          => 'classification',
    relation_name => 'data',
    y_column_name => 'label',
    algorithm     => 'random_forest',
    search        => 'grid',
    search_params => '{
        "n_estimators": [100, 200]
    }',
    search_args   => '{
        "cv": 5,
        "scoring": "f1_weighted"
    }'
);

Training with Complex Queries

Feature Engineering in SQL

-- Train directly on a query that engineers features in SQL.
SELECT * FROM pgml.train(
    project_name  => 'engineered_model',
    task          => 'classification',
    relation_name => $$
        SELECT
            -- raw features
            age,
            income,
            -- derived features
            EXTRACT(YEAR FROM age(date_of_birth)) AS calculated_age,
            income / NULLIF(household_size, 0) AS income_per_person,
            CASE
                WHEN age < 30 THEN 'young'
                WHEN age < 50 THEN 'middle'
                ELSE 'senior'
            END AS age_group,
            -- per-customer aggregate (correlated subquery)
            (SELECT AVG(purchase_amount)
             FROM purchases p
             WHERE p.customer_id = c.id) AS avg_purchase,
            -- target column
            is_premium_customer
        FROM customers c
    $$,
    y_column_name => 'is_premium_customer',
    algorithm     => 'xgboost'
);

Joining Multiple Tables

-- Train on features assembled from multiple tables.
-- Each child table is pre-aggregated before joining: joining orders and
-- reviews in the same query fans out to (orders x reviews) rows per user,
-- inflating SUM(total) by the review count and weighting AVG(rating) by
-- the order count.
SELECT * FROM pgml.train(
    project_name  => 'multi_table_model',
    task          => 'regression',
    relation_name => $$
        SELECT
            u.age,
            u.location,
            COALESCE(o.order_count, 0) AS order_count,
            o.lifetime_value,
            r.avg_rating,
            o.last_order_date,
            -- target
            u.predicted_churn_score
        FROM users u
        LEFT JOIN (
            SELECT
                user_id,
                COUNT(*)        AS order_count,
                SUM(total)      AS lifetime_value,
                MAX(created_at) AS last_order_date
            FROM orders
            GROUP BY user_id
        ) o ON o.user_id = u.id
        LEFT JOIN (
            SELECT user_id, AVG(rating) AS avg_rating
            FROM reviews
            GROUP BY user_id
        ) r ON r.user_id = u.id
    $$,
    y_column_name => 'predicted_churn_score'
);

Model Versioning

PostgresML automatically versions models:

-- Each pgml.train() call on the same project records a new model version.
SELECT * FROM pgml.train(
    'my_project', 'classification', 'data', 'label',
    algorithm => 'random_forest'
);

SELECT * FROM pgml.train(
    'my_project', 'classification', 'data', 'label',
    algorithm => 'xgboost'
);

-- List every trained version for a project, newest first.
SELECT
    id,
    created_at,
    algorithm,
    metrics->>'accuracy' AS accuracy
FROM pgml.models
WHERE project_name = 'my_project'
ORDER BY created_at DESC;

-- Deploy specific version
-- NOTE(review): confirm against the installed PostgresML release that
-- pgml.deploy() accepts a model_id argument; some releases deploy by
-- (project_name, strategy, algorithm) instead.
SELECT pgml.deploy(
project_name => 'my_project',
model_id => 42
);

Monitoring Training

Track Training Progress

-- Ten most recent training jobs across all projects.
SELECT
    project_name,
    algorithm,
    status,
    created_at,
    training_time
FROM pgml.models
ORDER BY created_at DESC
LIMIT 10;

Compare Model Performance

-- Accuracy summary per algorithm for one project, best first.
SELECT
    algorithm,
    AVG((metrics->>'accuracy')::float) AS avg_accuracy,
    MIN((metrics->>'accuracy')::float) AS min_accuracy,
    MAX((metrics->>'accuracy')::float) AS max_accuracy,
    COUNT(*) AS num_models
FROM pgml.models
WHERE project_name = 'my_project'
GROUP BY algorithm
ORDER BY avg_accuracy DESC;

Best Practices

1. Start with a Baseline

-- Train a simple, interpretable baseline before reaching for complex models.
SELECT * FROM pgml.train(
    'baseline_model',
    'classification',
    'data',
    'label',
    algorithm => 'logistic_regression'
);

2. Use Appropriate Metrics

-- For imbalanced classification, optimize F1 (or AUC) instead of accuracy.
-- A grid search needs search_params to enumerate candidates; with the
-- default '{}' there is nothing to search over, so supply a grid.
SELECT * FROM pgml.train(
    'balanced_model',
    'classification',
    'data',
    'label',
    algorithm => 'xgboost',
    search => 'grid',
    search_params => '{
        "max_depth": [3, 5, 7]
    }',
    search_args => '{
        "cv": 5,
        "scoring": "f1_weighted"
    }'
);

3. Validate with Test Data

-- Hold out 20% of rows for testing; stratified sampling keeps the class
-- distribution the same across the train and test splits.
-- NOTE(review): confirm 'stratified' is a supported test_sampling value
-- in the installed PostgresML version.
SELECT * FROM pgml.train(
    'validated_model',
    'classification',
    'data',
    'label',
    test_size     => 0.2,
    test_sampling => 'stratified'
);

4. Regular Retraining

-- Retrain periodically on a rolling 30-day window of fresh data.
-- Dollar quoting avoids doubling the quotes inside the query string.
SELECT * FROM pgml.train(
    'production_model',
    'classification',
    $$SELECT * FROM data WHERE created_at > NOW() - INTERVAL '30 days'$$,
    'label',
    algorithm => 'xgboost'
);

Next Steps

Continue learning about: