Making Predictions
Learn how to use trained models for inference and predictions in PostgresML.
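Every example below assumes a model has already been trained under a project name that pgml.predict can reference. As a refresher, a minimal training call looks roughly like this (the houses table, price column, and project name are placeholders for your own data):

-- Train a regression model; predictions then use the project name
SELECT pgml.train(
    'house_price_predictor',
    task => 'regression',
    relation_name => 'houses',
    y_column_name => 'price'
);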
Basic Predictions
Single Prediction
-- Predict for a single observation
SELECT pgml.predict(
'my_model',
ARRAY[feature1_value, feature2_value, feature3_value]
) AS prediction;
-- Example: Predict house price
SELECT pgml.predict(
'house_price_predictor',
ARRAY[2000.0, 3, 2.5, 2020] -- square_feet, bedrooms, bathrooms, year_built
) AS predicted_price;
Batch Predictions
-- Predict on entire table
SELECT
id,
feature1,
feature2,
pgml.predict(
'my_model',
ARRAY[feature1, feature2, feature3]
) AS prediction
FROM my_data;
-- Store predictions in a new table
CREATE TABLE predictions AS
SELECT
id,
pgml.predict('my_model', ARRAY[feature1, feature2]) AS prediction,
actual_value
FROM test_data;
Probability Predictions
Get prediction probabilities for classification:
-- Get class probabilities
SELECT pgml.predict_proba(
'spam_detector',
ARRAY[word_count, has_links::int, sender_reputation]
) AS probabilities;
-- Returns: {"spam": 0.85, "not_spam": 0.15}
-- Use in queries
SELECT
email_id,
subject,
(pgml.predict_proba(
'spam_detector',
ARRAY[word_count, has_links::int, sender_reputation]
)->'spam')::float AS spam_probability
FROM emails
WHERE (pgml.predict_proba(
'spam_detector',
ARRAY[word_count, has_links::int, sender_reputation]
)->'spam')::float > 0.9
ORDER BY spam_probability DESC;
Prediction Strategies
Real-Time Predictions
-- Predict in real-time within transactions
BEGIN;
INSERT INTO orders (user_id, product_id, quantity)
VALUES (123, 456, 2);
-- Predict delivery time
UPDATE orders
SET estimated_delivery = pgml.predict(
'delivery_predictor',
ARRAY[
(SELECT distance FROM users WHERE id = 123),
(SELECT weight FROM products WHERE id = 456) * 2 -- total weight for quantity 2
]
)
WHERE user_id = 123 AND product_id = 456;
COMMIT;
Cached Predictions
-- Pre-compute predictions for better performance
CREATE MATERIALIZED VIEW user_predictions AS
SELECT
user_id,
pgml.predict('churn_predictor', user_features) AS churn_score,
pgml.predict('ltv_predictor', user_features) AS lifetime_value
FROM (
SELECT
user_id,
ARRAY[age, tenure, purchase_count, avg_order_value] AS user_features
FROM user_stats
) features;
-- Refresh periodically
REFRESH MATERIALIZED VIEW user_predictions;
-- Fast queries
SELECT * FROM user_predictions WHERE churn_score > 0.7;
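Refreshing by hand works, but for a steady cadence a scheduler such as the pg_cron extension can run the refresh automatically. A minimal sketch, assuming pg_cron is installed and enabled on the database:

-- Refresh the cached predictions at the top of every hour
SELECT cron.schedule(
    'refresh-user-predictions',
    '0 * * * *',
    'REFRESH MATERIALIZED VIEW user_predictions'
);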
Prediction with Different Models
Using Specific Model Versions
-- Predict with latest deployed model (default)
SELECT pgml.predict('my_project', ARRAY[1, 2, 3]);
-- Predict with specific model ID
SELECT pgml.predict('my_project', ARRAY[1, 2, 3], model_id => 42);
-- Deploy and predict with specific model
SELECT pgml.deploy('my_project', model_id => 42);
SELECT pgml.predict('my_project', ARRAY[1, 2, 3]);
A/B Testing Models
-- Compare predictions from different models
WITH predictions AS (
SELECT
id,
actual_value,
pgml.predict('model_v1', features) AS pred_v1,
pgml.predict('model_v2', features) AS pred_v2
FROM test_data
)
SELECT
AVG(ABS(actual_value - pred_v1)) AS mae_v1,
AVG(ABS(actual_value - pred_v2)) AS mae_v2,
AVG(POWER(actual_value - pred_v1, 2)) AS mse_v1,
AVG(POWER(actual_value - pred_v2, 2)) AS mse_v2
FROM predictions;
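If the two candidates are versions inside the same project rather than separate projects, the same comparison can target specific model IDs with the model_id parameter shown earlier (the IDs 42 and 43 here are placeholders):

WITH predictions AS (
    SELECT
        id,
        actual_value,
        pgml.predict('my_project', features, model_id => 42) AS pred_a,
        pgml.predict('my_project', features, model_id => 43) AS pred_b
    FROM test_data
)
SELECT
    AVG(ABS(actual_value - pred_a)) AS mae_a,
    AVG(ABS(actual_value - pred_b)) AS mae_b
FROM predictions;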
Handling Edge Cases
Missing Values
-- Use COALESCE for missing values
SELECT pgml.predict(
'my_model',
ARRAY[
COALESCE(feature1, 0),
COALESCE(feature2, -1),
COALESCE(feature3, AVG(feature3) OVER ())
]
) AS prediction
FROM my_data;
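Note that the window-function average above is computed over the batch being scored, which can differ from what the model saw during training. To freeze imputation values at training time instead, precompute them once (feature_defaults is a hypothetical helper table):

-- Capture training-time defaults once
CREATE TABLE feature_defaults AS
SELECT AVG(feature3) AS feature3_default
FROM training_data;

-- Impute with the stored default when scoring
SELECT pgml.predict(
    'my_model',
    ARRAY[
        COALESCE(feature1, 0),
        COALESCE(feature2, -1),
        COALESCE(feature3, feature3_default)
    ]
) AS prediction
FROM my_data, feature_defaults;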
Out-of-Range Values
-- Clip values to training range
WITH feature_bounds AS (
SELECT
MIN(age) AS min_age,
MAX(age) AS max_age,
MIN(income) AS min_income,
MAX(income) AS max_income
FROM training_data
)
SELECT pgml.predict(
'my_model',
ARRAY[
LEAST(GREATEST(age, min_age), max_age),
LEAST(GREATEST(income, min_income), max_income)
]
) AS prediction
FROM new_data, feature_bounds;
Prediction Performance Optimization
Indexing for Fast Lookup
-- Index the columns used to filter rows before predicting
CREATE INDEX idx_features ON my_data (feature1, feature2, feature3);
-- The index speeds up row selection; the prediction call itself is unchanged
SELECT
id,
pgml.predict('my_model', ARRAY[feature1, feature2, feature3]) AS pred
FROM my_data
WHERE feature1 > 100;
Batch Processing
-- Process in batches for large tables
DO $$
DECLARE
batch_size INT := 10000;
offset_val INT := 0;
rows_inserted INT;
BEGIN
LOOP
INSERT INTO predictions (id, prediction)
SELECT
id,
pgml.predict('my_model', ARRAY[f1, f2, f3])
FROM my_data
ORDER BY id
LIMIT batch_size OFFSET offset_val;
GET DIAGNOSTICS rows_inserted = ROW_COUNT;
EXIT WHEN rows_inserted = 0;
offset_val := offset_val + rows_inserted;
RAISE NOTICE 'Processed % rows', offset_val;
END LOOP;
END $$;
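Because OFFSET re-reads every previously processed row, each batch gets slower as the offset grows. A keyset variant that remembers the last processed id avoids the rescans; this sketch assumes my_data has an integer primary key id:

DO $$
DECLARE
    batch_size INT := 10000;
    last_id INT := 0;
    batch_max INT;
BEGIN
    LOOP
        -- Find the upper bound of the next batch
        SELECT MAX(id) INTO batch_max
        FROM (
            SELECT id FROM my_data
            WHERE id > last_id
            ORDER BY id
            LIMIT batch_size
        ) batch;

        EXIT WHEN batch_max IS NULL; -- nothing left to process

        INSERT INTO predictions (id, prediction)
        SELECT id, pgml.predict('my_model', ARRAY[f1, f2, f3])
        FROM my_data
        WHERE id > last_id AND id <= batch_max;

        last_id := batch_max;
        RAISE NOTICE 'Processed through id %', last_id;
    END LOOP;
END $$;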
Parallel Predictions
-- Enable parallel query execution
SET max_parallel_workers_per_gather = 4;
-- Parallel predictions on large dataset
SELECT
id,
pgml.predict('my_model', ARRAY[f1, f2, f3]) AS prediction
FROM large_table
WHERE date > '2024-01-01';
Prediction Pipelines
Multi-Stage Predictions
-- Use output of one model as input to another
WITH stage1 AS (
SELECT
id,
features,
pgml.predict('preprocessor', features) AS processed_features
FROM raw_data
),
stage2 AS (
SELECT
id,
pgml.predict('classifier', ARRAY[processed_features]) AS classification
FROM stage1
)
SELECT
id,
classification,
pgml.predict('confidence_model', ARRAY[classification]) AS confidence
FROM stage2;
Ensemble Predictions
-- Average predictions from multiple models
SELECT
id,
(
pgml.predict('model_rf', features) +
pgml.predict('model_xgb', features) +
pgml.predict('model_lgb', features)
) / 3.0 AS ensemble_prediction
FROM data;
-- Weighted ensemble
SELECT
id,
(
pgml.predict('model_rf', features) * 0.3 +
pgml.predict('model_xgb', features) * 0.5 +
pgml.predict('model_lgb', features) * 0.2
) AS weighted_prediction
FROM data;
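To keep the weights defined in one place rather than repeated in every query, the weighted combination can be wrapped in a small SQL function (ensemble_predict is a hypothetical name; the weights are illustrative):

CREATE OR REPLACE FUNCTION ensemble_predict(features REAL[])
RETURNS REAL AS $$
    SELECT (
        pgml.predict('model_rf', features) * 0.3 +
        pgml.predict('model_xgb', features) * 0.5 +
        pgml.predict('model_lgb', features) * 0.2
    )::REAL;
$$ LANGUAGE sql STABLE;

-- Callers no longer need to know the weights
SELECT id, ensemble_predict(features) AS weighted_prediction
FROM data;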
Conditional Predictions
Rule-Based + ML Hybrid
-- Apply business rules with ML predictions
SELECT
id,
CASE
WHEN feature1 < 0 THEN 'reject'
WHEN feature2 > 1000 THEN 'approve'
ELSE pgml.predict('decision_model', ARRAY[feature1, feature2])::text -- cast so every CASE branch returns text
END AS final_decision
FROM applications;
Confidence-Based Actions
-- Take action based on prediction confidence
WITH predictions AS (
SELECT
id,
pgml.predict('fraud_detector', features) AS prediction,
pgml.predict_proba('fraud_detector', features) AS probabilities
FROM transactions
)
SELECT
id,
prediction,
CASE
WHEN (probabilities->'fraud')::float > 0.95 THEN 'block'
WHEN (probabilities->'fraud')::float > 0.70 THEN 'review'
ELSE 'approve'
END AS action
FROM predictions;
Monitoring Predictions
Prediction Logging
-- Log predictions for monitoring
CREATE TABLE prediction_logs (
id SERIAL PRIMARY KEY,
model_name TEXT,
features JSONB,
prediction FLOAT,
predicted_at TIMESTAMPTZ DEFAULT NOW()
);
-- Log each prediction
INSERT INTO prediction_logs (model_name, features, prediction)
SELECT
'my_model',
jsonb_build_object('f1', f1, 'f2', f2),
pgml.predict('my_model', ARRAY[f1, f2])
FROM new_data;
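The drift and performance queries below filter the log by model name and time, so an index on those columns keeps monitoring fast as the table grows:

CREATE INDEX idx_prediction_logs_model_time
    ON prediction_logs (model_name, predicted_at);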
Drift Detection
-- Monitor prediction distribution over time
SELECT
DATE_TRUNC('day', predicted_at) AS day,
AVG(prediction) AS avg_prediction,
STDDEV(prediction) AS stddev_prediction,
MIN(prediction) AS min_prediction,
MAX(prediction) AS max_prediction,
COUNT(*) AS prediction_count
FROM prediction_logs
WHERE model_name = 'my_model'
GROUP BY DATE_TRUNC('day', predicted_at)
ORDER BY day DESC;
Performance Tracking
-- Compare predictions vs actuals
SELECT
DATE_TRUNC('week', created_at) AS week,
AVG(ABS(prediction - actual_value)) AS mae,
AVG(POWER(prediction - actual_value, 2)) AS mse,
CORR(prediction, actual_value) AS correlation
FROM (
SELECT
pl.predicted_at AS created_at,
pl.prediction,
a.actual_value
FROM prediction_logs pl
JOIN actuals a ON pl.id = a.prediction_id
) results
GROUP BY week
ORDER BY week DESC;
Integration Patterns
API Endpoint Pattern
-- Create a function for API calls
CREATE OR REPLACE FUNCTION predict_customer_churn(
customer_age INT,
customer_tenure INT,
monthly_charges FLOAT
) RETURNS JSONB AS $$
SELECT jsonb_build_object(
'churn_probability', pgml.predict_proba(
'churn_model',
ARRAY[customer_age, customer_tenure, monthly_charges]
),
'predicted_class', pgml.predict(
'churn_model',
ARRAY[customer_age, customer_tenure, monthly_charges]
),
'model_version', (
SELECT id FROM pgml.deployed_models
WHERE project_name = 'churn_model'
)
);
$$ LANGUAGE sql STABLE;
-- Call from application
SELECT predict_customer_churn(35, 24, 89.99);
Trigger-Based Predictions
-- Automatically predict on insert
CREATE OR REPLACE FUNCTION predict_on_insert()
RETURNS TRIGGER AS $$
BEGIN
NEW.predicted_value := pgml.predict(
'auto_predictor',
ARRAY[NEW.feature1, NEW.feature2, NEW.feature3]
);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER prediction_trigger
BEFORE INSERT ON my_table
FOR EACH ROW
EXECUTE FUNCTION predict_on_insert();
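The trigger function assumes my_table already has a predicted_value column for it to populate; if it does not, add one first (REAL is assumed here to match pgml.predict's output):

ALTER TABLE my_table ADD COLUMN predicted_value REAL;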
-- Insert automatically includes prediction
INSERT INTO my_table (feature1, feature2, feature3)
VALUES (1.0, 2.0, 3.0);
-- predicted_value is automatically filled
Best Practices
1. Use Appropriate Model Versions
-- Always specify which model version for production
SELECT pgml.deploy('production_model', strategy => 'best_score');
2. Handle Errors Gracefully
-- Fall back to a default when the prediction comes back NULL
SELECT
id,
COALESCE(
pgml.predict('my_model', ARRAY[f1, f2]),
default_value
) AS prediction
FROM data;
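COALESCE only covers the case where the model returns NULL. To also survive errors raised by the prediction call itself (for example, a wrong feature count), wrap it in a PL/pgSQL function with an exception handler; safe_predict is a hypothetical helper, not part of PostgresML:

CREATE OR REPLACE FUNCTION safe_predict(
    project TEXT,
    features REAL[],
    fallback REAL DEFAULT NULL
) RETURNS REAL AS $$
BEGIN
    RETURN pgml.predict(project, features);
EXCEPTION WHEN OTHERS THEN
    -- Log the failure and return the caller-supplied fallback
    RAISE WARNING 'prediction failed for %: %', project, SQLERRM;
    RETURN fallback;
END;
$$ LANGUAGE plpgsql;

-- Usage
SELECT id, safe_predict('my_model', ARRAY[f1, f2]::REAL[], 0.0::REAL) AS prediction
FROM data;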
3. Monitor Prediction Latency
-- Track prediction time
EXPLAIN ANALYZE
SELECT pgml.predict('my_model', ARRAY[1, 2, 3]);
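EXPLAIN ANALYZE times a single call. For ongoing latency tracking across many calls, the pg_stat_statements extension (if it is installed and preloaded) aggregates execution statistics per statement:

CREATE EXTENSION IF NOT EXISTS pg_stat_statements;

-- Average prediction latency per statement shape
-- (the column is named mean_time on PostgreSQL 12 and older)
SELECT query, calls, mean_exec_time
FROM pg_stat_statements
WHERE query ILIKE '%pgml.predict%'
ORDER BY mean_exec_time DESC;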
4. Keep Features Consistent
-- Use a view to ensure feature consistency
CREATE VIEW prediction_features AS
SELECT
id,
ARRAY[
feature1,
feature2,
LOG(feature3 + 1),
EXTRACT(EPOCH FROM date_col)
] AS features
FROM raw_data;
-- Always predict from this view
SELECT
id,
pgml.predict('my_model', features) AS prediction
FROM prediction_features;
Next Steps
Explore more advanced topics:
- Transformers & NLP - Natural language processing
- Vector Operations - Embeddings and semantic search
- Advanced Examples - Real-world use cases