Making Predictions

Learn how to use trained models for inference and predictions in PostgresML.

Basic Predictions

Single Prediction

-- Predict for a single observation
SELECT pgml.predict(
    'my_model',
    ARRAY[feature1_value, feature2_value, feature3_value]
) AS prediction;

-- Example: Predict house price
SELECT pgml.predict(
    'house_price_predictor',
    ARRAY[2000.0, 3, 2.5, 2020] -- square_feet, bedrooms, bathrooms, year_built
) AS predicted_price;

Batch Predictions

-- Predict on an entire table
SELECT
    id,
    feature1,
    feature2,
    pgml.predict(
        'my_model',
        ARRAY[feature1, feature2, feature3]
    ) AS prediction
FROM my_data;

-- Store predictions in a new table
CREATE TABLE predictions AS
SELECT
    id,
    pgml.predict('my_model', ARRAY[feature1, feature2]) AS prediction,
    actual_value
FROM test_data;

Probability Predictions

Get prediction probabilities for classification:

-- Get class probabilities
SELECT pgml.predict_proba(
    'spam_detector',
    ARRAY[word_count, has_links::int, sender_reputation]
) AS probabilities;

-- Returns: {"spam": 0.85, "not_spam": 0.15}

-- Use in queries
SELECT
    email_id,
    subject,
    (pgml.predict_proba(
        'spam_detector',
        ARRAY[word_count, has_links::int, sender_reputation]
    )->'spam')::float AS spam_probability
FROM emails
WHERE (pgml.predict_proba(
    'spam_detector',
    ARRAY[word_count, has_links::int, sender_reputation]
)->'spam')::float > 0.9
ORDER BY spam_probability DESC;

Prediction Strategies

Real-Time Predictions

-- Predict in real-time within transactions
BEGIN;

INSERT INTO orders (user_id, product_id, quantity)
VALUES (123, 456, 2);

-- Predict delivery time
UPDATE orders
SET estimated_delivery = pgml.predict(
    'delivery_predictor',
    ARRAY[
        (SELECT distance FROM users WHERE id = 123),
        (SELECT weight FROM products WHERE id = 456) * 2
    ]
)
WHERE user_id = 123 AND product_id = 456;

COMMIT;

Cached Predictions

-- Pre-compute predictions for better performance
CREATE MATERIALIZED VIEW user_predictions AS
SELECT
    user_id,
    pgml.predict('churn_predictor', user_features) AS churn_score,
    pgml.predict('ltv_predictor', user_features) AS lifetime_value
FROM (
    SELECT
        user_id,
        ARRAY[age, tenure, purchase_count, avg_order_value] AS user_features
    FROM user_stats
) features;

-- Refresh periodically
REFRESH MATERIALIZED VIEW user_predictions;

-- Fast queries
SELECT * FROM user_predictions WHERE churn_score > 0.7;
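
If the pg_cron extension is available in the database, the refresh can be scheduled without an external job runner. This is a sketch rather than a PostgresML feature; the job name and schedule are illustrative:

-- Refresh the cached predictions nightly
-- (assumes the pg_cron extension is installed and enabled)
SELECT cron.schedule(
    'refresh-user-predictions',
    '0 2 * * *', -- every day at 02:00
    'REFRESH MATERIALIZED VIEW user_predictions'
);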

Prediction with Different Models

Using Specific Model Versions

-- Predict with latest deployed model (default)
SELECT pgml.predict('my_project', ARRAY[1, 2, 3]);

-- Predict with specific model ID
SELECT pgml.predict('my_project', ARRAY[1, 2, 3], model_id => 42);

-- Deploy and predict with specific model
SELECT pgml.deploy('my_project', model_id => 42);
SELECT pgml.predict('my_project', ARRAY[1, 2, 3]);
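
Before deploying a different version, it can help to check which model is currently serving the project. PostgresML exposes deployment metadata in the pgml schema; the pgml.deployed_models view is also used in the API example later in this guide, though exact column names can vary between versions:

-- Inspect the currently deployed model for each project
SELECT * FROM pgml.deployed_models;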

A/B Testing Models

-- Compare predictions from different models
WITH predictions AS (
    SELECT
        id,
        actual_value,
        pgml.predict('model_v1', features) AS pred_v1,
        pgml.predict('model_v2', features) AS pred_v2
    FROM test_data
)
SELECT
    AVG(ABS(actual_value - pred_v1)) AS mae_v1,
    AVG(ABS(actual_value - pred_v2)) AS mae_v2,
    AVG(POWER(actual_value - pred_v1, 2)) AS mse_v1,
    AVG(POWER(actual_value - pred_v2, 2)) AS mse_v2
FROM predictions;

Handling Edge Cases

Missing Values

-- Use COALESCE for missing values
SELECT pgml.predict(
    'my_model',
    ARRAY[
        COALESCE(feature1, 0),
        COALESCE(feature2, -1),
        COALESCE(feature3, AVG(feature3) OVER ()) -- mean computed over the scoring data
    ]
) AS prediction
FROM my_data;
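
The window average above is computed over the scoring data, which can drift away from the statistics the model saw during training. One option is to persist the training-time statistics once and reuse them at prediction time. A sketch, assuming a training_data table as in the examples below; the feature_stats table is illustrative:

-- Persist training-time imputation values once (illustrative helper table)
CREATE TABLE feature_stats AS
SELECT AVG(feature3) AS feature3_mean
FROM training_data;

-- Reuse the stored mean when imputing at prediction time
SELECT pgml.predict(
    'my_model',
    ARRAY[
        COALESCE(feature1, 0),
        COALESCE(feature2, -1),
        COALESCE(feature3, feature_stats.feature3_mean)
    ]
) AS prediction
FROM my_data, feature_stats;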

Out-of-Range Values

-- Clip values to training range
WITH feature_bounds AS (
    SELECT
        MIN(age) AS min_age,
        MAX(age) AS max_age,
        MIN(income) AS min_income,
        MAX(income) AS max_income
    FROM training_data
)
SELECT pgml.predict(
    'my_model',
    ARRAY[
        LEAST(GREATEST(age, min_age), max_age),
        LEAST(GREATEST(income, min_income), max_income)
    ]
) AS prediction
FROM new_data, feature_bounds;
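
Instead of silently clipping, it can also be useful to flag rows that fall outside the training range so they can be reviewed. A sketch building on the same feature_bounds CTE; it assumes new_data has an id column:

-- Flag out-of-range rows in addition to clipping them
WITH feature_bounds AS (
    SELECT
        MIN(age) AS min_age,
        MAX(age) AS max_age,
        MIN(income) AS min_income,
        MAX(income) AS max_income
    FROM training_data
)
SELECT
    id,
    (age NOT BETWEEN min_age AND max_age
        OR income NOT BETWEEN min_income AND max_income) AS out_of_range,
    pgml.predict(
        'my_model',
        ARRAY[
            LEAST(GREATEST(age, min_age), max_age),
            LEAST(GREATEST(income, min_income), max_income)
        ]
    ) AS prediction
FROM new_data, feature_bounds;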

Prediction Performance Optimization

Indexing for Fast Lookup

-- Index the columns used to filter rows for prediction
CREATE INDEX idx_features ON my_data (feature1, feature2, feature3);

-- The WHERE clause can use the index, so fewer rows reach the prediction call
SELECT
    id,
    pgml.predict('my_model', ARRAY[feature1, feature2, feature3]) AS pred
FROM my_data
WHERE feature1 > 100;

Batch Processing

-- Process in batches for large tables
DO $$
DECLARE
    batch_size INT := 10000;
    offset_val INT := 0;
BEGIN
    LOOP
        INSERT INTO predictions (id, prediction)
        SELECT
            id,
            pgml.predict('my_model', ARRAY[f1, f2, f3])
        FROM my_data
        ORDER BY id
        LIMIT batch_size OFFSET offset_val;

        -- FOUND becomes false once a batch inserts no rows
        EXIT WHEN NOT FOUND;
        offset_val := offset_val + batch_size;
        RAISE NOTICE 'Processed % rows', offset_val;
    END LOOP;
END $$;
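
OFFSET-based batching re-reads and discards the skipped rows, so each batch gets slower as the offset grows. Keyset pagination avoids that by remembering the last processed key. A sketch, assuming my_data.id is an indexed, monotonically increasing key:

-- Keyset-paginated batching (assumes my_data.id is an indexed,
-- monotonically increasing key)
DO $$
DECLARE
    batch_size INT := 10000;
    last_id BIGINT := 0;
    next_id BIGINT;
BEGIN
    LOOP
        -- Upper bound of the next batch
        SELECT MAX(id) INTO next_id
        FROM (
            SELECT id FROM my_data
            WHERE id > last_id
            ORDER BY id
            LIMIT batch_size
        ) batch;

        EXIT WHEN next_id IS NULL;

        INSERT INTO predictions (id, prediction)
        SELECT
            id,
            pgml.predict('my_model', ARRAY[f1, f2, f3])
        FROM my_data
        WHERE id > last_id AND id <= next_id;

        last_id := next_id;
        RAISE NOTICE 'Processed up to id %', last_id;
    END LOOP;
END $$;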

Parallel Predictions

-- Enable parallel query execution
SET max_parallel_workers_per_gather = 4;

-- Parallel predictions on large dataset
SELECT
    id,
    pgml.predict('my_model', ARRAY[f1, f2, f3]) AS prediction
FROM large_table
WHERE date > '2024-01-01';
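
Whether the planner can actually parallelize the scan also depends on the prediction function's parallel-safety marking. That marking can be checked in the standard PostgreSQL catalogs (plain catalog SQL, not a PostgresML feature):

-- Check the parallel-safety marking of pgml.predict
-- ('s' = safe, 'r' = restricted, 'u' = unsafe)
SELECT p.proname, p.proparallel
FROM pg_proc p
JOIN pg_namespace n ON n.oid = p.pronamespace
WHERE n.nspname = 'pgml'
  AND p.proname = 'predict';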

Prediction Pipelines

Multi-Stage Predictions

-- Use the output of one model as input to another
WITH stage1 AS (
    SELECT
        id,
        features,
        pgml.predict('preprocessor', features) AS processed_features
    FROM raw_data
),
stage2 AS (
    SELECT
        id,
        -- Wrap the scalar stage-1 output in an array, since predict expects an array of features
        pgml.predict('classifier', ARRAY[processed_features]) AS classification
    FROM stage1
)
SELECT
    id,
    classification,
    pgml.predict('confidence_model', ARRAY[classification]) AS confidence
FROM stage2;

Ensemble Predictions

-- Average predictions from multiple models
SELECT
    id,
    (
        pgml.predict('model_rf', features) +
        pgml.predict('model_xgb', features) +
        pgml.predict('model_lgb', features)
    ) / 3.0 AS ensemble_prediction
FROM data;

-- Weighted ensemble
SELECT
    id,
    (
        pgml.predict('model_rf', features) * 0.3 +
        pgml.predict('model_xgb', features) * 0.5 +
        pgml.predict('model_lgb', features) * 0.2
    ) AS weighted_prediction
FROM data;

Conditional Predictions

Rule-Based + ML Hybrid

-- Apply business rules alongside ML predictions
SELECT
    id,
    CASE
        WHEN feature1 < 0 THEN 'reject'
        WHEN feature2 > 1000 THEN 'approve'
        -- Cast the numeric prediction so all CASE branches share the TEXT type
        ELSE pgml.predict('decision_model', ARRAY[feature1, feature2])::TEXT
    END AS final_decision
FROM applications;

Confidence-Based Actions

-- Take action based on prediction confidence
WITH predictions AS (
    SELECT
        id,
        pgml.predict('fraud_detector', features) AS prediction,
        pgml.predict_proba('fraud_detector', features) AS probabilities
    FROM transactions
)
SELECT
    id,
    prediction,
    CASE
        WHEN (probabilities->'fraud')::float > 0.95 THEN 'block'
        WHEN (probabilities->'fraud')::float > 0.70 THEN 'review'
        ELSE 'approve'
    END AS action
FROM predictions;

Monitoring Predictions

Prediction Logging

-- Log predictions for monitoring
CREATE TABLE prediction_logs (
    id SERIAL PRIMARY KEY,
    model_name TEXT,
    features JSONB,
    prediction FLOAT,
    predicted_at TIMESTAMPTZ DEFAULT NOW()
);

-- Log each prediction
INSERT INTO prediction_logs (model_name, features, prediction)
SELECT
    'my_model',
    jsonb_build_object('f1', f1, 'f2', f2),
    pgml.predict('my_model', ARRAY[f1, f2])
FROM new_data;

Drift Detection

-- Monitor prediction distribution over time
SELECT
    DATE_TRUNC('day', predicted_at) AS day,
    AVG(prediction) AS avg_prediction,
    STDDEV(prediction) AS stddev_prediction,
    MIN(prediction) AS min_prediction,
    MAX(prediction) AS max_prediction,
    COUNT(*) AS prediction_count
FROM prediction_logs
WHERE model_name = 'my_model'
GROUP BY DATE_TRUNC('day', predicted_at)
ORDER BY day DESC;
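
The same log can drive a simple alert by comparing each day's mean prediction to a trailing baseline. A sketch; the 30-day window and 2-sigma threshold are illustrative choices:

-- Flag days whose mean prediction deviates from a trailing 30-day baseline
WITH daily AS (
    SELECT
        DATE_TRUNC('day', predicted_at) AS day,
        AVG(prediction) AS avg_prediction
    FROM prediction_logs
    WHERE model_name = 'my_model'
    GROUP BY 1
),
baseline AS (
    SELECT
        AVG(avg_prediction) AS baseline_mean,
        STDDEV(avg_prediction) AS baseline_stddev
    FROM daily
    WHERE day >= NOW() - INTERVAL '30 days'
)
SELECT
    d.day,
    d.avg_prediction,
    ABS(d.avg_prediction - b.baseline_mean) > 2 * b.baseline_stddev AS possible_drift
FROM daily d, baseline b
ORDER BY d.day DESC;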

Performance Tracking

-- Compare predictions vs actuals
SELECT
    DATE_TRUNC('week', created_at) AS week,
    AVG(ABS(prediction - actual_value)) AS mae,
    AVG(POWER(prediction - actual_value, 2)) AS mse,
    CORR(prediction, actual_value) AS correlation
FROM (
    SELECT
        pl.predicted_at AS created_at,
        pl.prediction,
        a.actual_value
    FROM prediction_logs pl
    JOIN actuals a ON pl.id = a.prediction_id
) results
GROUP BY week
ORDER BY week DESC;
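
For classification projects, the same join supports tracking accuracy. A sketch that assumes prediction and actual_value hold comparable class labels:

-- Weekly accuracy for a classification model
SELECT
    DATE_TRUNC('week', pl.predicted_at) AS week,
    AVG((pl.prediction = a.actual_value)::int) AS accuracy,
    COUNT(*) AS n
FROM prediction_logs pl
JOIN actuals a ON pl.id = a.prediction_id
GROUP BY week
ORDER BY week DESC;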

Integration Patterns

API Endpoint Pattern

-- Create a function for API calls
CREATE OR REPLACE FUNCTION predict_customer_churn(
    customer_age INT,
    customer_tenure INT,
    monthly_charges FLOAT
) RETURNS JSONB AS $$
    SELECT jsonb_build_object(
        'churn_probability', pgml.predict_proba(
            'churn_model',
            ARRAY[customer_age, customer_tenure, monthly_charges]
        ),
        'predicted_class', pgml.predict(
            'churn_model',
            ARRAY[customer_age, customer_tenure, monthly_charges]
        ),
        'model_version', (
            SELECT id FROM pgml.deployed_models
            WHERE project_name = 'churn_model'
        )
    );
$$ LANGUAGE sql STABLE;

-- Call from application
SELECT predict_customer_churn(35, 24, 89.99);

Trigger-Based Predictions

-- Automatically predict on insert
CREATE OR REPLACE FUNCTION predict_on_insert()
RETURNS TRIGGER AS $$
BEGIN
    NEW.predicted_value := pgml.predict(
        'auto_predictor',
        ARRAY[NEW.feature1, NEW.feature2, NEW.feature3]
    );
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;

CREATE TRIGGER prediction_trigger
BEFORE INSERT ON my_table
FOR EACH ROW
EXECUTE FUNCTION predict_on_insert();

-- Insert automatically includes prediction
INSERT INTO my_table (feature1, feature2, feature3)
VALUES (1.0, 2.0, 3.0);
-- predicted_value is filled in by the trigger

Best Practices

1. Use Appropriate Model Versions

-- Always specify which model version for production
SELECT pgml.deploy('production_model', strategy => 'best_score');

2. Handle Errors Gracefully

-- COALESCE substitutes a default when the prediction returns NULL
SELECT
    id,
    COALESCE(
        pgml.predict('my_model', ARRAY[f1, f2]),
        default_value
    ) AS prediction
FROM data;
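
COALESCE only covers a NULL result; a runtime failure (for example, a feature-count mismatch) still raises an error. A PL/pgSQL wrapper with an exception handler can turn such failures into a fallback value. A sketch; the function name and fallback are illustrative, and the REAL return type assumes a regression-style prediction:

-- Fall back to a default when the prediction errors or returns NULL
-- (illustrative wrapper, not part of PostgresML)
CREATE OR REPLACE FUNCTION safe_predict(
    project TEXT,
    features REAL[],
    fallback REAL DEFAULT NULL
) RETURNS REAL AS $$
BEGIN
    RETURN COALESCE(pgml.predict(project, features), fallback);
EXCEPTION WHEN OTHERS THEN
    RETURN fallback;
END;
$$ LANGUAGE plpgsql;

SELECT safe_predict('my_model', ARRAY[1.0, 2.0]::REAL[], -1.0);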

3. Monitor Prediction Latency

-- Track prediction time
EXPLAIN ANALYZE
SELECT pgml.predict('my_model', ARRAY[1, 2, 3]);
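
EXPLAIN ANALYZE is good for one-off checks; for an ad-hoc measurement inside the database, clock_timestamp() (a standard PostgreSQL function that keeps advancing within a transaction, unlike NOW()) can bracket the call. A minimal sketch:

-- Measure wall-clock latency of a single prediction
DO $$
DECLARE
    started TIMESTAMPTZ := clock_timestamp();
    result REAL;
BEGIN
    result := pgml.predict('my_model', ARRAY[1, 2, 3]);
    RAISE NOTICE 'prediction=%, latency_ms=%',
        result,
        1000 * EXTRACT(EPOCH FROM clock_timestamp() - started);
END $$;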

4. Keep Features Consistent

-- Use a view to ensure feature consistency
CREATE VIEW prediction_features AS
SELECT
    id,
    ARRAY[
        feature1,
        feature2,
        LOG(feature3 + 1),
        EXTRACT(EPOCH FROM date_col)
    ] AS features
FROM raw_data;

-- Always predict from this view
SELECT
    id,
    pgml.predict('my_model', features) AS prediction
FROM prediction_features;

Next Steps

Explore more advanced topics: