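Click the link below to run this example in one click, with no environment setup required ↓ ↓ ↓

PyCaret 3.0 feature preview: https://www.heywhale.com/mw/project/62d378a2de3942457a52c905

PyCaret Overview

PyCaret is an open-source, low-code machine learning library for Python. It automates machine learning workflows, replacing hundreds of lines of code with just a few, and serves as an end-to-end tool for machine learning and model management.

User Guide

Official user guide: Welcome to PyCaret - PyCaret Official

Supported Features

PyCaret's modules cover data preprocessing, model training, hyperparameter search, model interpretability, model selection, and experiment logging.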

Cheat Sheet

Install the PyCaret 3.0 Preview Release

!pip install llvmlite==0.38.1 -i https://pypi.tuna.tsinghua.edu.cn/simple --ignore-installed
!pip install pycaret==3.0.0rc3 -i https://pypi.tuna.tsinghua.edu.cn/simple #--ignore-installed

Check the installed version

import pycaret
print(pycaret.__version__)
> 3.0.0.rc3

PyCaret Built-in Datasets

from pycaret.datasets import get_data
all_datasets = get_data('index')
all_datasets
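| | Dataset | Data Types | Default Task | Target Variable 1 | Target Variable 2 | # Instances | # Attributes | Missing Values |
|---|---|---|---|---|---|---|---|---|
| 0 | anomaly | Multivariate | Anomaly Detection | None | None | 1000 | 10 | N |
| 1 | france | Multivariate | Association Rule Mining | InvoiceNo | Description | 8557 | 8 | N |
| 2 | germany | Multivariate | Association Rule Mining | InvoiceNo | Description | 9495 | 8 | N |
| 3 | bank | Multivariate | Classification (Binary) | deposit | None | 45211 | 17 | N |
| 4 | blood | Multivariate | Classification (Binary) | Class | None | 748 | 5 | N |
| 5 | cancer | Multivariate | Classification (Binary) | Class | None | 683 | 10 | N |
| 6 | credit | Multivariate | Classification (Binary) | default | None | 24000 | 24 | N |
| 7 | diabetes | Multivariate | Classification (Binary) | Class variable | None | 768 | 9 | N |
| 8 | electrical_grid | Multivariate | Classification (Binary) | stabf | None | 10000 | 14 | N |
| 9 | employee | Multivariate | Classification (Binary) | left | None | 14999 | 10 | N |
| 10 | heart | Multivariate | Classification (Binary) | DEATH | None | 200 | 16 | N |
| 11 | heart_disease | Multivariate | Classification (Binary) | Disease | None | 270 | 14 | N |
| 12 | hepatitis | Multivariate | Classification (Binary) | Class | None | 154 | 32 | Y |
| 13 | income | Multivariate | Classification (Binary) | income >50K | None | 32561 | 14 | Y |
| 14 | juice | Multivariate | Classification (Binary) | Purchase | None | 1070 | 15 | N |
| 15 | nba | Multivariate | Classification (Binary) | TARGET_5Yrs | None | 1340 | 21 | N |
| 16 | wine | Multivariate | Classification (Binary) | type | None | 6498 | 13 | N |
| 17 | telescope | Multivariate | Classification (Binary) | Class | None | 19020 | 11 | N |
| 18 | titanic | Multivariate | Classification (Binary) | Survived | None | 891 | 11 | Y |
| 19 | us_presidential_election_results | Multivariate | Classification (Binary) | party_winner | None | 497 | 7 | N |
| 20 | glass | Multivariate | Classification (Multiclass) | Type | None | 214 | 10 | N |
| 21 | iris | Multivariate | Classification (Multiclass) | species | None | 150 | 5 | N |
| 22 | poker | Multivariate | Classification (Multiclass) | CLASS | None | 100000 | 11 | N |
| 23 | questions | Multivariate | Classification (Multiclass) | Next_Question | None | 499 | 4 | N |
| 24 | satellite | Multivariate | Classification (Multiclass) | Class | None | 6435 | 37 | N |
| 25 | CTG | Multivariate | Classification (Multiclass) | NSP | None | 2129 | 40 | Y |
| 26 | asia_gdp | Multivariate | Clustering | None | None | 40 | 11 | N |
| 27 | elections | Multivariate | Clustering | None | None | 3195 | 54 | Y |
| 28 | facebook | Multivariate | Clustering | None | None | 7050 | 12 | N |
| 29 | ipl | Multivariate | Clustering | None | None | 153 | 25 | N |
| 30 | jewellery | Multivariate | Clustering | None | None | 505 | 4 | N |
| 31 | mice | Multivariate | Clustering | None | None | 1080 | 82 | Y |
| 32 | migration | Multivariate | Clustering | None | None | 233 | 12 | N |
| 33 | perfume | Multivariate | Clustering | None | None | 20 | 29 | N |
| 34 | pokemon | Multivariate | Clustering | None | None | 800 | 13 | Y |
| 35 | population | Multivariate | Clustering | None | None | 255 | 56 | Y |
| 36 | public_health | Multivariate | Clustering | None | None | 224 | 21 | N |
| 37 | seeds | Multivariate | Clustering | None | None | 210 | 7 | N |
| 38 | wholesale | Multivariate | Clustering | None | None | 440 | 8 | N |
| 39 | tweets | Text | NLP | tweet | None | 8594 | 2 | N |
| 40 | amazon | Text | NLP / Classification | reviewText | None | 20000 | 2 | N |
| 41 | kiva | Text | NLP / Classification | en | None | 6818 | 7 | N |
| 42 | spx | Text | NLP / Regression | text | None | 874 | 4 | N |
| 43 | wikipedia | Text | NLP / Classification | Text | None | 500 | 3 | N |
| 44 | automobile | Multivariate | Regression | price | None | 202 | 26 | Y |
| 45 | bike | Multivariate | Regression | cnt | None | 17379 | 15 | N |
| 46 | boston | Multivariate | Regression | medv | None | 506 | 14 | N |
| 47 | concrete | Multivariate | Regression | strength | None | 1030 | 9 | N |
| 48 | diamond | Multivariate | Regression | Price | None | 6000 | 8 | N |
| 49 | energy | Multivariate | Regression | Heating Load | Cooling Load | 768 | 10 | N |
| 50 | forest | Multivariate | Regression | area | None | 517 | 13 | N |
| 51 | gold | Multivariate | Regression | Gold_T+22 | None | 2558 | 121 | N |
| 52 | house | Multivariate | Regression | SalePrice | None | 1461 | 81 | Y |
| 53 | insurance | Multivariate | Regression | charges | None | 1338 | 7 | N |
| 54 | parkinsons | Multivariate | Regression | PPE | None | 5875 | 22 | N |
| 55 | traffic | Multivariate | Regression | traffic_volume | None | 48204 | 8 | N |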

PyCaret Time Series Forecasting

Import modules

# Import the loader for PyCaret's built-in datasets
from pycaret.datasets import get_data
# Import the PyCaret time-series forecasting module (new in 3.x)
from pycaret.time_series import *

import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')

Load the data

# Load the built-in airline passengers dataset
data = get_data('airline')
data

Period
1949-01    112.0
1949-02    118.0
1949-03    132.0
1949-04    129.0
1949-05    121.0
           ...  
1960-08    606.0
1960-09    508.0
1960-10    461.0
1960-11    390.0
1960-12    432.0
Freq: M, Name: Number of airline passengers, Length: 144, dtype: float64

Exploratory Data Analysis

# Plot the time series
data.plot()

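# Run statistical tests on the series.
# check_stats() works on the active experiment, so setup() must be called first
s = setup(data, fh = 12, session_id = 123)
check_stats()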

| | Test | Test Name | Data | Property | Setting | Value |
|---|---|---|---|---|---|---|
| 0 | Summary | Statistics | Transformed | Length | | 144.0 |
| 1 | Summary | Statistics | Transformed | # Missing Values | | 0.0 |
| 2 | Summary | Statistics | Transformed | Mean | | 280.298611 |
| 3 | Summary | Statistics | Transformed | Median | | 265.5 |
| 4 | Summary | Statistics | Transformed | Standard Deviation | | 119.966317 |
| 5 | Summary | Statistics | Transformed | Variance | | 14391.917201 |
| 6 | Summary | Statistics | Transformed | Kurtosis | | -0.364942 |
| 7 | Summary | Statistics | Transformed | Skewness | | 0.58316 |
| 8 | Summary | Statistics | Transformed | # Distinct Values | | 118.0 |
| 9 | White Noise | Ljung-Box | Transformed | Test Statictic | {'alpha': 0.05, 'K': 24} | 1606.083817 |
| 10 | White Noise | Ljung-Box | Transformed | Test Statictic | {'alpha': 0.05, 'K': 48} | 1933.155822 |
| 11 | White Noise | Ljung-Box | Transformed | p-value | {'alpha': 0.05, 'K': 24} | 0.0 |
| 12 | White Noise | Ljung-Box | Transformed | p-value | {'alpha': 0.05, 'K': 48} | 0.0 |
| 13 | White Noise | Ljung-Box | Transformed | White Noise | {'alpha': 0.05, 'K': 24} | False |
| 14 | White Noise | Ljung-Box | Transformed | White Noise | {'alpha': 0.05, 'K': 48} | False |
| 15 | Stationarity | ADF | Transformed | Stationarity | {'alpha': 0.05} | False |
| 16 | Stationarity | ADF | Transformed | p-value | {'alpha': 0.05} | 0.99188 |
| 17 | Stationarity | ADF | Transformed | Test Statistic | {'alpha': 0.05} | 0.815369 |
| 18 | Stationarity | ADF | Transformed | Critical Value 1% | {'alpha': 0.05} | -3.481682 |
| 19 | Stationarity | ADF | Transformed | Critical Value 5% | {'alpha': 0.05} | -2.884042 |
| 20 | Stationarity | ADF | Transformed | Critical Value 10% | {'alpha': 0.05} | -2.57877 |
| 21 | Stationarity | KPSS | Transformed | Trend Stationarity | {'alpha': 0.05} | True |
| 22 | Stationarity | KPSS | Transformed | p-value | {'alpha': 0.05} | 0.1 |
| 23 | Stationarity | KPSS | Transformed | Test Statistic | {'alpha': 0.05} | 0.09615 |
| 24 | Stationarity | KPSS | Transformed | Critical Value 10% | {'alpha': 0.05} | 0.119 |
| 25 | Stationarity | KPSS | Transformed | Critical Value 5% | {'alpha': 0.05} | 0.146 |
| 26 | Stationarity | KPSS | Transformed | Critical Value 2.5% | {'alpha': 0.05} | 0.176 |
| 27 | Stationarity | KPSS | Transformed | Critical Value 1% | {'alpha': 0.05} | 0.216 |
| 28 | Normality | Shapiro | Transformed | Normality | {'alpha': 0.05} | False |
| 29 | Normality | Shapiro | Transformed | p-value | {'alpha': 0.05} | 0.000068 |
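The Ljung-Box rows check whether the series is white noise: under that hypothesis the statistic Q = n(n+2) Σ_{k=1..K} r_k²/(n−k), built from the sample autocorrelations r_k, is approximately chi-squared with K degrees of freedom, so a large Q (p-value 0.0 here) means significant autocorrelation. The NumPy sketch below computes Q directly. It is an illustration only: `ljung_box_q` is a hypothetical helper, the input is a synthetic trend-plus-seasonal series rather than the airline data, and the chi-squared critical value is hard-coded; PyCaret presumably delegates this test to statsmodels.

```python
import numpy as np


def ljung_box_q(x, max_lag):
    """Ljung-Box statistic Q = n(n+2) * sum_{k=1..K} r_k^2 / (n - k)."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    d = x - x.mean()
    denom = np.sum(d ** 2)
    lags = np.arange(1, max_lag + 1)
    # Sample autocorrelation r_k for lags 1..K
    r = np.array([np.sum(d[k:] * d[:-k]) / denom for k in lags])
    return n * (n + 2) * np.sum(r ** 2 / (n - lags))


# Synthetic monthly series with a linear trend and a 12-month cycle (NOT the airline data)
t = np.arange(144)
series = 100 + 2 * t + 10 * np.sin(2 * np.pi * t / 12)

CHI2_95_DF24 = 36.415  # 95th percentile of chi-squared with 24 degrees of freedom
q24 = ljung_box_q(series, 24)
print(f"Q(24) = {q24:.1f}, white noise rejected: {q24 > CHI2_95_DF24}")
```

A strongly trending series like this one gives a Q far above the critical value, which matches the `White Noise: False` verdict PyCaret reports for the airline series.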

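Find the Best Model

# Initialize the experiment: 12-period forecast horizon, fixed random seed
s = setup(data, fh = 12, session_id = 123)
# Compare all candidate models (ranked by MASE, lower is better)
best = compare_models()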

| | Description | Value |
|---|---|---|
| 0 | session_id | 123 |
| 1 | Target | Number of airline passengers |
| 2 | Approach | Univariate |
| 3 | Exogenous Variables | Not Present |
| 4 | Original data shape | (144, 1) |
| 5 | Transformed data shape | (144, 1) |
| 6 | Transformed train set shape | (132, 1) |
| 7 | Transformed test set shape | (12, 1) |
| 8 | Rows with missing values | 0.0% |
| 9 | Fold Generator | ExpandingWindowSplitter |
| 10 | Fold Number | 3 |
| 11 | Enforce Prediction Interval | False |
| 12 | Seasonal Period(s) Tested | 12 |
| 13 | Seasonality Present | True |
| 14 | Seasonalities Detected | [12] |
| 15 | Primary Seasonality | 12 |
| 16 | Target Strictly Positive | True |
| 17 | Target White Noise | No |
| 18 | Recommended d | 1 |
| 19 | Recommended Seasonal D | 1 |
| 20 | Preprocess | False |
| 21 | CPU Jobs | -1 |
| 22 | Use GPU | False |
| 23 | Log Experiment | False |
| 24 | Experiment Name | ts-default-name |
| 25 | USI | 98d2 |
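With `fh = 12`, setup() holds out the last 12 months as the test set (132 training rows, 12 test rows) and cross-validates on the training part with an `ExpandingWindowSplitter` over 3 folds. The pure-Python sketch below lays out such expanding-window folds. It assumes the folds are spaced `fh` rows apart and that the last fold's test window ends at the last training row (both are assumptions about the defaults, not taken from PyCaret's source); `expanding_window_folds` and `month_label` are hypothetical helpers. Under these assumptions the cutoffs come out as 1956-12, 1957-12 and 1958-12, the same cutoff values that create_model reports for this data.

```python
def expanding_window_folds(n_train, fh, n_folds, step=None):
    """Return (train_indices, test_indices) pairs for expanding-window CV."""
    step = fh if step is None else step
    folds = []
    for i in range(n_folds):
        # The last fold's test window ends exactly at the last training row
        cutoff = n_train - fh - step * (n_folds - 1 - i) - 1
        train_idx = range(0, cutoff + 1)  # always starts at row 0, so the window expands
        test_idx = range(cutoff + 1, cutoff + 1 + fh)
        folds.append((train_idx, test_idx))
    return folds


def month_label(i, start_year=1949, start_month=1):
    """Map a row index of a monthly series to a 'YYYY-MM' label."""
    total = (start_month - 1) + i
    return f"{start_year + total // 12}-{total % 12 + 1:02d}"


folds = expanding_window_folds(n_train=132, fh=12, n_folds=3)
for train_idx, test_idx in folds:
    print(
        f"cutoff {month_label(train_idx[-1])} | train rows {len(train_idx)} | "
        f"test {month_label(test_idx[0])} .. {month_label(test_idx[-1])}"
    )
```

Each fold trains on everything up to its cutoff and scores the next 12 months, so later folds see more history than earlier ones.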
| | Model | MASE | RMSSE | MAE | RMSE | MAPE | SMAPE | R2 | TT (Sec) |
|---|---|---|---|---|---|---|---|---|---|
| exp_smooth | Exponential Smoothing | 0.5716 | 0.5997 | 16.7767 | 19.7954 | 0.0422 | 0.0427 | 0.8954 | 0.0400 |
| ets | ETS | 0.5931 | 0.6212 | 17.4172 | 20.5108 | 0.0440 | 0.0445 | 0.8882 | 0.0767 |
| et_cds_dt | Extra Trees w/ Cond. Deseasonalize & Detrending | 0.6602 | 0.7288 | 19.4653 | 24.1050 | 0.0484 | 0.0484 | 0.8459 | 0.1167 |
| huber_cds_dt | Huber w/ Cond. Deseasonalize & Detrending | 0.6813 | 0.7866 | 20.0334 | 25.9670 | 0.0491 | 0.0499 | 0.8113 | 0.0267 |
| arima | ARIMA | 0.6830 | 0.6735 | 20.0069 | 22.2199 | 0.0501 | 0.0507 | 0.8677 | 0.0900 |
| lr_cds_dt | Linear w/ Cond. Deseasonalize & Detrending | 0.7004 | 0.7702 | 20.6084 | 25.4401 | 0.0509 | 0.0514 | 0.8215 | 0.0267 |
| ridge_cds_dt | Ridge w/ Cond. Deseasonalize & Detrending | 0.7004 | 0.7703 | 20.6086 | 25.4405 | 0.0509 | 0.0514 | 0.8215 | 0.0233 |
| lar_cds_dt | Least Angular Regressor w/ Cond. Deseasonalize & Detrending | 0.7004 | 0.7702 | 20.6084 | 25.4401 | 0.0509 | 0.0514 | 0.8215 | 0.0233 |
| en_cds_dt | Elastic Net w/ Cond. Deseasonalize & Detrending | 0.7029 | 0.7732 | 20.6816 | 25.5362 | 0.0511 | 0.0516 | 0.8201 | 0.0267 |
| lasso_cds_dt | Lasso w/ Cond. Deseasonalize & Detrending | 0.7048 | 0.7751 | 20.7373 | 25.6005 | 0.0512 | 0.0517 | 0.8193 | 0.0200 |
| catboost_cds_dt | CatBoost Regressor w/ Cond. Deseasonalize & Detrending | 0.7106 | 0.8146 | 20.9112 | 26.8907 | 0.0505 | 0.0509 | 0.8085 | 0.9433 |
| br_cds_dt | Bayesian Ridge w/ Cond. Deseasonalize & Detrending | 0.7112 | 0.7837 | 20.9213 | 25.8795 | 0.0515 | 0.0521 | 0.8144 | 0.0233 |
| knn_cds_dt | K Neighbors w/ Cond. Deseasonalize & Detrending | 0.7162 | 0.8157 | 21.1613 | 26.9700 | 0.0521 | 0.0529 | 0.7811 | 0.0300 |
| auto_arima | Auto ARIMA | 0.7181 | 0.7114 | 21.0297 | 23.4661 | 0.0525 | 0.0531 | 0.8509 | 1.6967 |
| gbr_cds_dt | Gradient Boosting w/ Cond. Deseasonalize & Detrending | 0.7938 | 0.9310 | 23.3723 | 30.7344 | 0.0569 | 0.0576 | 0.7417 | 0.0367 |
| xgboost_cds_dt | Extreme Gradient Boosting w/ Cond. Deseasonalize & Detrending | 0.8155 | 0.9591 | 24.0738 | 31.6950 | 0.0582 | 0.0592 | 0.7118 | 152.0800 |
| lightgbm_cds_dt | Light Gradient Boosting w/ Cond. Deseasonalize & Detrending | 0.8156 | 0.9117 | 24.0002 | 30.0956 | 0.0575 | 0.0587 | 0.7561 | 86.6467 |
| rf_cds_dt | Random Forest w/ Cond. Deseasonalize & Detrending | 0.8327 | 0.9465 | 24.5290 | 31.2635 | 0.0600 | 0.0606 | 0.7360 | 0.1400 |
| ada_cds_dt | AdaBoost w/ Cond. Deseasonalize & Detrending | 0.8825 | 1.0292 | 25.9471 | 33.9304 | 0.0619 | 0.0637 | 0.6725 | 0.0500 |
| llar_cds_dt | Lasso Least Angular Regressor w/ Cond. Deseasonalize & Detrending | 0.9670 | 1.1915 | 28.4499 | 39.3303 | 0.0665 | 0.0693 | 0.5738 | 0.0233 |
| theta | Theta Forecaster | 0.9729 | 1.0306 | 28.3192 | 33.8639 | 0.0670 | 0.0700 | 0.6710 | 0.0167 |
| omp_cds_dt | Orthogonal Matching Pursuit w/ Cond. Deseasonalize & Detrending | 1.0090 | 1.2370 | 29.6294 | 40.8121 | 0.0685 | 0.0718 | 0.5462 | 0.0200 |
| dt_cds_dt | Decision Tree w/ Cond. Deseasonalize & Detrending | 1.0429 | 1.2226 | 30.4800 | 40.1912 | 0.0726 | 0.0753 | 0.5362 | 0.0333 |
| snaive | Seasonal Naive Forecaster | 1.1479 | 1.0945 | 33.3611 | 35.9139 | 0.0832 | 0.0879 | 0.6072 | 0.0133 |
| par_cds_dt | Passive Aggressive w/ Cond. Deseasonalize & Detrending | 1.2472 | 1.3081 | 36.7727 | 43.3215 | 0.0935 | 0.0961 | 0.4968 | 0.0233 |
| polytrend | Polynomial Trend Forecaster | 1.6523 | 1.9202 | 48.6301 | 63.4299 | 0.1170 | 0.1216 | -0.0784 | 0.0100 |
| croston | Croston | 1.9311 | 2.3517 | 56.6180 | 77.5856 | 0.1295 | 0.1439 | -0.6281 | 0.0100 |
| naive | Naive Forecaster | 2.3599 | 2.7612 | 69.0278 | 91.0322 | 0.1569 | 0.1792 | -1.2216 | 0.6467 |
| grand_means | Grand Means Forecaster | 5.5306 | 5.2596 | 162.4117 | 173.6492 | 0.4000 | 0.5075 | -7.0462 | 0.5100 |
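The table is sorted by MASE. The error metrics can be sketched in NumPy as below, under these assumptions: MASE and RMSSE divide the forecast error by the in-sample error of a seasonal naive forecast y_t = y_{t−sp} on the training data (PyCaret computes its metrics through sktime, and the exact `sp` it passes is an assumption here); MAPE and SMAPE are returned as fractions, not percentages, consistent with the 0.04-range values in the table; SMAPE uses the 2|e|/(|y|+|ŷ|) form. The function names are hypothetical.

```python
import numpy as np


def mae(y_true, y_pred):
    return np.mean(np.abs(np.asarray(y_true) - np.asarray(y_pred)))


def mase(y_true, y_pred, y_train, sp=1):
    # Scale: in-sample MAE of the seasonal naive forecast y_t = y_{t-sp}
    y_train = np.asarray(y_train, dtype=float)
    scale = np.mean(np.abs(y_train[sp:] - y_train[:-sp]))
    return mae(y_true, y_pred) / scale


def rmsse(y_true, y_pred, y_train, sp=1):
    # Scale: in-sample MSE of the seasonal naive forecast
    y_train = np.asarray(y_train, dtype=float)
    scale = np.mean((y_train[sp:] - y_train[:-sp]) ** 2)
    mse = np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2)
    return np.sqrt(mse / scale)


def mape(y_true, y_pred):
    y_true, y_pred = np.asarray(y_true, dtype=float), np.asarray(y_pred, dtype=float)
    return np.mean(np.abs(y_true - y_pred) / np.abs(y_true))


def smape(y_true, y_pred):
    y_true, y_pred = np.asarray(y_true, dtype=float), np.asarray(y_pred, dtype=float)
    return np.mean(2 * np.abs(y_true - y_pred) / (np.abs(y_true) + np.abs(y_pred)))


# Toy data: a steadily rising training series and a 2-step forecast
y_train = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y_true = np.array([7.0, 8.0])
y_pred = np.array([7.5, 7.5])
print(mase(y_true, y_pred, y_train), rmsse(y_true, y_pred, y_train))
print(mape(y_true, y_pred), smape(y_true, y_pred))
```

Because MASE and RMSSE are scaled by a naive benchmark, a value below 1 means the model beats that benchmark in-sample, which is why they are comparable across series of different magnitudes.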
# Inspect the best model
best

ExponentialSmoothing(seasonal='mul', sp=12, trend='add')
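The selected model is Holt-Winters exponential smoothing with an additive trend (`trend='add'`) and multiplicative seasonality (`seasonal='mul'`) of period 12 (`sp=12`). Its update equations can be sketched in NumPy as below. This is a simplified illustration, not PyCaret's implementation: the hypothetical `holt_winters_add_mul` uses hand-picked, fixed smoothing parameters alpha, beta and gamma, and a first-two-seasons heuristic for the initial level, trend and seasonal factors, whereas PyCaret's `exp_smooth` presumably wraps sktime's ExponentialSmoothing (built on statsmodels), which estimates these from the data.

```python
import numpy as np


def holt_winters_add_mul(y, m, alpha=0.3, beta=0.1, gamma=0.1, h=12):
    """Additive-trend, multiplicative-seasonal Holt-Winters with FIXED parameters.

    Hypothetical helper for illustration; alpha/beta/gamma are not estimated.
    Requires len(y) >= 2 * m.
    """
    y = np.asarray(y, dtype=float)
    n = len(y)
    # Heuristic initial states from the first two seasons
    level = y[:m].mean()
    trend = (y[m:2 * m].mean() - y[:m].mean()) / m
    seasonal = list(y[:m] / level)  # one multiplicative factor per season position

    for t in range(m, n):
        prev_level, prev_trend = level, trend
        s_prev = seasonal[t - m]  # factor for the same position one season earlier
        # Level: deseasonalized observation blended with the one-step projection
        level = alpha * y[t] / s_prev + (1 - alpha) * (prev_level + prev_trend)
        # Trend: latest change in level blended with the previous trend
        trend = beta * (level - prev_level) + (1 - beta) * prev_trend
        # Seasonal factor: observation / projected level, blended with the old factor
        seasonal.append(gamma * y[t] / (prev_level + prev_trend) + (1 - gamma) * s_prev)

    # h-step forecast: linear trend extrapolation times the latest factor for each position
    return np.array([(level + k * trend) * seasonal[n - m + (k - 1) % m] for k in range(1, h + 1)])


# Toy check: a flat series with a fixed multiplicative pattern is reproduced
# (up to floating-point error)
pattern = np.array([0.8, 1.0, 1.2, 1.0])
history = 100 * np.tile(pattern, 5)
forecast = holt_winters_add_mul(history, m=4, h=8)
print(np.round(forecast, 6))
```

Multiplicative seasonality lets the seasonal swing grow with the level, which fits the airline series, whose yearly peaks widen as passenger numbers rise.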

# Train the exponential smoothing model on its own and report per-fold cross-validation scores
exp_smooth = create_model('exp_smooth')
print(exp_smooth)

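| | cutoff | MASE | RMSSE | MAE | RMSE | MAPE | SMAPE | R2 |
|---|---|---|---|---|---|---|---|---|
| 0 | 1956-12 | 0.4985 | 0.5735 | 14.5584 | 18.7730 | 0.0366 | 0.0376 | 0.8853 |
| 1 | 1957-12 | 0.5088 | 0.5368 | 15.5548 | 18.2243 | 0.0420 | 0.0411 | 0.9130 |
| 2 | 1958-12 | 0.7075 | 0.6888 | 20.2167 | 22.3888 | 0.0479 | 0.0494 | 0.8879 |
| Mean | NaT | 0.5716 | 0.5997 | 16.7767 | 19.7954 | 0.0422 | 0.0427 | 0.8954 |
| SD | NaT | 0.0962 | 0.0647 | 2.4663 | 1.8475 | 0.0046 | 0.0049 | 0.0125 |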
ExponentialSmoothing(seasonal='mul', sp=12, trend='add')
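The Mean and SD rows summarize the three cross-validation folds, and the Mean row is exactly the exp_smooth line of the compare_models table. Recomputing them from the per-fold scores with NumPy, as below, suggests that SD is the population standard deviation (ddof=0, NumPy's default) rather than the sample one; small differences in the last digit can come from the per-fold scores already being rounded to four decimals.

```python
import numpy as np

# Per-fold scores copied from the create_model output above
fold_mase = np.array([0.4985, 0.5088, 0.7075])
fold_rmse = np.array([18.7730, 18.2243, 22.3888])

# np.std uses ddof=0 (population standard deviation) by default
print("MASE mean/SD:", round(fold_mase.mean(), 4), round(fold_mase.std(), 4))
print("RMSE mean/SD:", round(fold_rmse.mean(), 4), round(fold_rmse.std(), 4))
# For comparison, the sample standard deviation (ddof=1) is noticeably larger
print("MASE sample SD:", round(fold_mase.std(ddof=1), 4))
```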
# Forecast the 12-month hold-out set and score the predictions
pred = predict_model(exp_smooth)

| | Model | MASE | RMSSE | MAE | RMSE | MAPE | SMAPE | R2 |
|---|---|---|---|---|---|---|---|---|
| 0 | Exponential Smoothing | 0.3383 | 0.4576 | 10.3023 | 15.8096 | 0.0221 | 0.0216 | 0.9549 |
