Tool Series: TensorFlow Decision Forests (1) Building, Training, and Evaluating Models
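1. Introduction
Decision Forests (DF) are a family of machine learning algorithms for supervised classification, regression, and ranking. As the name suggests, DFs use decision trees as building blocks. Today, the two most popular DF training algorithms are Random Forests and Gradient Boosted Decision Trees.
TensorFlow Decision Forests (TF-DF) is a library for training, evaluating, interpreting, and running inference with Decision Forest models.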
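To make "decision trees as building blocks" concrete, here is a minimal sketch of how a forest can combine its trees by majority vote. The three "trees" are hypothetical hand-written rules with made-up thresholds, not trained models, and this is not how TF-DF is implemented internally; it only illustrates the voting idea.

```python
from collections import Counter

# Each "tree" is a hypothetical hand-written rule (made-up thresholds),
# standing in for a trained decision tree.
def tree_a(x):
    return "Gentoo" if x["flipper_length_mm"] > 206 else "Adelie"

def tree_b(x):
    return "Gentoo" if x["body_mass_g"] > 4500 else "Adelie"

def tree_c(x):
    return "Chinstrap" if x["bill_length_mm"] > 45 else "Adelie"

def forest_predict(trees, example):
    """Return the class that receives the most votes across all trees."""
    votes = Counter(tree(example) for tree in trees)
    return votes.most_common(1)[0][0]

trees = [tree_a, tree_b, tree_c]
small_bird = {"flipper_length_mm": 181.0, "body_mass_g": 3750.0, "bill_length_mm": 39.1}
large_bird = {"flipper_length_mm": 220.0, "body_mass_g": 5000.0, "bill_length_mm": 47.0}

small_prediction = forest_predict(trees, small_bird)  # all three trees vote "Adelie"
large_prediction = forest_predict(trees, large_bird)  # two of three trees vote "Gentoo"
print(small_prediction, large_prediction)
```

A single tree can be wrong (here `tree_c` votes "Chinstrap" for the large bird), but the vote of the whole forest is more stable.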
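In this tutorial, you will learn how to:
- Train a multi-class classification Random Forest model on a dataset containing numerical, categorical, and missing features.
- Evaluate the model on a test dataset.
- Prepare the model for TensorFlow Serving.
- Examine the overall structure of the model and the importance of each feature.
- Re-train the model with a different learning algorithm (Gradient Boosted Decision Trees).
- Use a different set of input features.
- Change the hyperparameters of the model.
- Preprocess the features.
- Train a regression model.
2. Installing TensorFlow Decision Forests
Install TF-DF by running the following cell.
# Install the tensorflow_decision_forests library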
!pip install tensorflow_decision_forests
Wurlitzer is needed to display the detailed training logs in Colab (when verbose=2 is passed to the model constructor).
!pip install wurlitzer
Requirement already satisfied: wurlitzer in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (3.0.3)
3. Importing libraries
# Import the required libraries
import tensorflow_decision_forests as tfdf  # decision forests library
import os  # operating system utilities
import numpy as np  # numerical computing
import pandas as pd  # data loading and manipulation
import tensorflow as tf  # deep learning framework
import math  # math utilities
The hidden code cell below limits the output height in Colab.
#@title
# Import the required modules
from IPython.core.magic import register_line_magic
from IPython.display import Javascript
from IPython.display import display as ipy_display
# Define a line magic that sets the maximum height of a cell's output
@register_line_magic
def set_cell_height(size):
  # Run JavaScript that caps the height of the output iframe
  ipy_display(
      Javascript("google.colab.output.setIframeHeight(0, true, {maxHeight: " +
                 str(size) + "})"))
# Check the version of TensorFlow Decision Forests
print("Found TensorFlow Decision Forests v" + tfdf.__version__)
Found TensorFlow Decision Forests v1.5.0
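4. Training a Random Forest model
In this section, we train, evaluate, analyze, and export a multi-class classification Random Forest trained on the Palmer's Penguins dataset.
Note: The dataset was exported to a csv file without any preprocessing: library(palmerpenguins); write.csv(penguins, file="penguins.csv", quote=F, row.names=F).
4.1 Loading the dataset and converting it into a tf.Dataset
This dataset is very small (344 examples) and stored as a .csv-like file. Therefore, use Pandas to load it.
Note: Pandas is practical because you don't have to type in the names of the input features to load them. For larger datasets (>1M examples), using the TensorFlow Dataset to read the files may be better suited.
Let's download the csv file (it already includes a header row) and load it:
# Download the dataset
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv
# Load the dataset into a Pandas DataFrame
dataset_df = pd.read_csv("/tmp/penguins.csv")
# Display the first 3 examples
dataset_df.head(3)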
| | species | island | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | sex | year |
|---|---|---|---|---|---|---|---|---|
| 0 | Adelie | Torgersen | 39.1 | 18.7 | 181.0 | 3750.0 | male | 2007 |
| 1 | Adelie | Torgersen | 39.5 | 17.4 | 186.0 | 3800.0 | female | 2007 |
| 2 | Adelie | Torgersen | 40.3 | 18.0 | 195.0 | 3250.0 | female | 2007 |
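The dataset contains a mix of numerical (e.g. bill_depth_mm), categorical (e.g. island), and missing features. TF-DF supports all these feature types natively (unlike NN-based models), so no preprocessing such as one-hot encoding, normalization, or an extra is_present feature is needed.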
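Before relying on native missing-value support, it can help to see where the gaps are. The sketch below counts missing cells per column using only the standard library. It assumes missing values are written as `NA` (the convention of R's `write.csv`, which Pandas also reads as missing by default). The first three rows are copied from the table above; the last row is a made-up example added for illustration.

```python
import csv
import io

# Rows 1-3 come from the table above; the last row is a made-up example
# with missing ("NA") cells, included only to illustrate the counting.
SAMPLE_CSV = """species,island,bill_length_mm,bill_depth_mm,flipper_length_mm,body_mass_g,sex,year
Adelie,Torgersen,39.1,18.7,181,3750,male,2007
Adelie,Torgersen,39.5,17.4,186,3800,female,2007
Adelie,Torgersen,40.3,18,195,3250,female,2007
Adelie,Torgersen,NA,NA,NA,NA,NA,2007
"""

def count_missing(csv_text, missing_token="NA"):
    """Count cells equal to the missing token (or empty) in each column."""
    reader = csv.DictReader(io.StringIO(csv_text))
    counts = {name: 0 for name in reader.fieldnames}
    for row in reader:
        for name, value in row.items():
            if value == missing_token or value == "":
                counts[name] += 1
    return counts

missing = count_missing(SAMPLE_CSV)
print(missing)
```

With the real DataFrame, `dataset_df.isna().sum()` gives the same kind of per-column count.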
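Labels are a bit different: Keras metrics expect integers. The label (species) is stored as a string, so let's convert it into an integer.
# Encode the categorical label as integers.
# Details:
# This step is needed when the classification label is stored as a string,
# because Keras expects integer classification labels.
# It can be skipped when using `pd_dataframe_to_tf_dataset` (see below).
# Name of the label column
label = "species"
# Collect the unique label values as a list
classes = dataset_df[label].unique().tolist()
# Print the label classes
print(f"Label classes: {classes}")
# Replace each label value with its index in the list of classes
dataset_df[label] = dataset_df[label].map(classes.index)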
Label classes: ['Adelie', 'Gentoo', 'Chinstrap']
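Because the model now works with integer labels, predictions have to be mapped back to species names through the same `classes` list. The sketch below shows the round trip in plain Python on a few illustrative label values (not rows taken from the dataset); `dict.fromkeys` is used here as a stand-in for Pandas' `unique()`, since both keep first-appearance order.

```python
# Illustrative label values, chosen so the classes appear in the same
# order as in the output above.
labels = ["Adelie", "Adelie", "Gentoo", "Chinstrap", "Gentoo"]

# dict.fromkeys keeps the order of first appearance, like pandas' unique().
classes = list(dict.fromkeys(labels))

# Encode: species name -> integer index.
encoded = [classes.index(name) for name in labels]

# Decode: integer index -> species name (e.g. for model predictions).
decoded = [classes[i] for i in encoded]

print(classes)
print(encoded)
```

Keeping `classes` around (or saving it with the model) is what makes the integer predictions interpretable later.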
Next, split the dataset into a training and a testing dataset:
# Split the dataset into a training and a testing dataset.
def split_dataset(dataset, test_ratio=0.30):
  """Splits a pandas DataFrame in two."""
  # Random boolean mask: True (test) with probability test_ratio, False (train) otherwise
  test_indices = np.random.rand(len(dataset)) < test_ratio
  # Return the training part first, then the testing part
  return dataset[~test_indices], dataset[test_indices]
# Split the dataset and keep the two parts as train_ds_pd and test_ds_pd
train_ds_pd, test_ds_pd = split_dataset(dataset_df)
# Print the number of examples in each part
print("{} examples in training, {} examples for testing.".format(
    len(train_ds_pd), len(test_ds_pd)))
239 examples in training, 105 examples for testing.
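The split above draws a fresh random mask on every run, so the 239/105 counts will vary between runs, and the test share is only approximately 30%. If a reproducible split is wanted, one option is to seed the random generator. The sketch below shows the same masking idea with the standard library; `split_indices` and its `seed` parameter are hypothetical names introduced here, not part of TF-DF.

```python
import random

def split_indices(num_examples, test_ratio=0.30, seed=42):
    """Return (train_indices, test_indices) using a seeded random mask."""
    rng = random.Random(seed)  # local generator, so the global state is untouched
    is_test = [rng.random() < test_ratio for _ in range(num_examples)]
    train = [i for i, flag in enumerate(is_test) if not flag]
    test = [i for i, flag in enumerate(is_test) if flag]
    return train, test

train_idx, test_idx = split_indices(344)
print(len(train_idx), len(test_idx))
```

With a DataFrame, the two index lists could be applied as `dataset_df.iloc[train_idx]` and `dataset_df.iloc[test_idx]`.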
Finally, convert the pandas dataframe (pd.DataFrame) into a TensorFlow dataset (tf.data.Dataset):
# Convert the Pandas DataFrames into TensorFlow datasets, specifying the label column
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label=label)  # training set
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label=label)  # testing set
Note: Recall that pd_dataframe_to_tf_dataset converts string labels to integers if necessary.
If you want to create the tf.data.Dataset yourself, there are a couple of things to remember:
- The learning algorithms work with a one-epoch dataset and without shuffling.
- The batch size does not impact the training algorithm, but a small value might slow down reading the dataset.
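4.2 Training the model
# Cap the cell output height at 300 (uses the line magic defined above)
%set_cell_height 300
# Specify the model: a Random Forest with verbosity level 2
model_1 = tfdf.keras.RandomForestModel(verbose=2)
# Train the model
model_1.fit(train_ds)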
<IPython.core.display.Javascript object>
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpblfnf8hv as temporary training directory
Reading training dataset...
Training tensor examples:
Features: {'island': <tf.Tensor 'data:0' shape=(None,) dtype=string>, 'bill_length_mm': <tf.Tensor 'data_1:0' shape=(None,) dtype=float64>, 'bill_depth_mm': <tf.Tensor 'data_2:0' shape=(None,) dtype=float64>, 'flipper_length_mm': <tf.Tensor 'data_3:0' shape=(None,) dtype=float64>, 'body_mass_g': <tf.Tensor 'data_4:0' shape=(None,) dtype=float64>, 'sex': <tf.Tensor 'data_5:0' shape=(None,) dtype=string>, 'year': <tf.Tensor 'data_6:0' shape=(None,) dtype=int64>}
Label: Tensor("data_7:0", shape=(None,), dtype=int64)
Weights: None
Normalized tensor features:
{'island': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data:0' shape=(None,) dtype=string>), 'bill_length_mm': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast:0' shape=(None,) dtype=float32>), 'bill_depth_mm': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_1:0' shape=(None,) dtype=float32>), 'flipper_length_mm': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_2:0' shape=(None,) dtype=float32>), 'body_mass_g': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_3:0' shape=(None,) dtype=float32>), 'sex': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_5:0' shape=(None,) dtype=string>), 'year': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_4:0' shape=(None,) dtype=float32>)}
Training dataset read in 0:00:03.556705. Found 239 examples.
Training model...
Standard output detected as not visible to the user e.g. running in a notebook. Creating a training log redirection. If training gets stuck, try calling tfdf.keras.set_training_logs_redirection(False).
[INFO 23-08-16 11:05:20.8059 UTC kernel.cc:773] Start Yggdrasil model training
[INFO 23-08-16 11:05:20.8059 UTC kernel.cc:774] Collect training examples
[INFO 23-08-16 11:05:20.8060 UTC kernel.cc:787] Dataspec guide:
column_guides {
column_name_pattern: "^__LABEL$"
type: CATEGORICAL
categorial {
min_vocab_frequency: 0
max_vocab_count: -1
}
}
default_column_guide {
categorial {
max_vocab_count: 2000
}
discretized_numerical {
maximum_num_bins: 255
}
}
ignore_columns_without_guides: false
detect_numerical_as_discretized_numerical: false
[INFO 23-08-16 11:05:20.8063 UTC kernel.cc:393] Number of batches: 1
[INFO 23-08-16 11:05:20.8064 UTC kernel.cc:394] Number of examples: 239
[INFO 23-08-16 11:05:20.8064 UTC kernel.cc:794] Training dataset:
Number of records: 239
Number of columns: 8
Number of columns by type:
NUMERICAL: 5 (62.5%)
CATEGORICAL: 3 (37.5%)
Columns:
NUMERICAL: 5 (62.5%)
1: "bill_depth_mm" NUMERICAL num-nas:1 (0.41841%) mean:17.0387 min:13.2 max:21.5 sd:1.97169
2: "bill_length_mm" NUMERICAL num-nas:1 (0.41841%) mean:44.0025 min:32.1 max:55.9 sd:5.27172
3: "body_mass_g" NUMERICAL num-nas:1 (0.41841%) mean:4230.57 min:2700 max:6300 sd:821.055
4: "flipper_length_mm" NUMERICAL num-nas:1 (0.41841%) mean:201.176 min:172 max:231 sd:14.2924
7: "year" NUMERICAL mean:2008.03 min:2007 max:2009 sd:0.807521
CATEGORICAL: 3 (37.5%)
0: "__LABEL" CATEGORICAL integerized vocab-size:4 no-ood-item
5: "island" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"Biscoe" 121 (50.6276%)
6: "sex" CATEGORICAL num-nas:7 (2.92887%) has-dict vocab-size:3 zero-ood-items most-frequent:"female" 120 (51.7241%)
Terminology:
nas: Number of non-available (i.e. missing) values.
ood: Out of dictionary.
manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
tokenized: The attribute value is obtained through tokenization.
has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
vocab-size: Number of unique values.
[INFO 23-08-16 11:05:20.8065 UTC kernel.cc:810] Configure learner
[INFO 23-08-16 11:05:20.8067 UTC kernel.cc:824] Training config:
learner: "RANDOM_FOREST"
features: "^bill_depth_mm$"
features: "^bill_length_mm$"
features: "^body_mass_g$"
features: "^flipper_length_mm$"
features: "^island$"
features: "^sex$"
features: "^year$"
label: "^__LABEL$"
task: CLASSIFICATION
random_seed: 123456
metadata {
framework: "TF Keras"
}
pure_serving_model: false
[yggdrasil_decision_forests.model.random_forest.proto.random_forest_config] {
num_trees: 300
decision_tree {
max_depth: 16
min_examples: 5
in_split_min_examples_check: true
keep_non_leaf_label_distribution: true
num_candidate_attributes: 0
missing_value_policy: GLOBAL_IMPUTATION
allow_na_conditions: false
categorical_set_greedy_forward {
sampling: 0.1
max_num_items: -1
min_item_frequency: 1
}
growing_strategy_local {
}
categorical {
cart {
}
}
axis_aligned_split {
}
internal {
sorting_strategy: PRESORTED
}
uplift {
min_examples_in_treatment: 5
split_score: KULLBACK_LEIBLER
}
}
winner_take_all_inference: true
compute_oob_performances: true
compute_oob_variable_importances: false
num_oob_variable_importances_permutations: 1
bootstrap_training_dataset: true
bootstrap_size_ratio: 1
adapt_bootstrap_size_ratio_for_maximum_training_duration: false
sampling_with_replacement: true
}
[INFO 23-08-16 11:05:20.8070 UTC kernel.cc:827] Deployment config:
cache_path: "/tmpfs/tmp/tmpblfnf8hv/working_cache"
num_threads: 32
try_resume_training: true
[INFO 23-08-16 11:05:20.8072 UTC kernel.cc:889] Train model
[INFO 23-08-16 11:05:20.8073 UTC random_forest.cc:416] Training random forest on 239 example(s) and 7 feature(s).
[INFO 23-08-16 11:05:20.8130 UTC random_forest.cc:802] Training of tree 1/300 (tree index:0) done accuracy:0.943182 logloss:2.04793
[INFO 23-08-16 11:05:20.8139 UTC random_forest.cc:802] Training of tree 11/300 (tree index:8) done accuracy:0.949367 logloss:0.383614
[INFO 23-08-16 11:05:20.8144 UTC random_forest.cc:802] Training of tree 21/300 (tree index:4) done accuracy:0.953975 logloss:0.386135
[INFO 23-08-16 11:05:20.8146 UTC random_forest.cc:802] Training of tree 35/300 (tree index:20) done accuracy:0.953975 logloss:0.249595
[INFO 23-08-16 11:05:20.8147 UTC random_forest.cc:802] Training of tree 50/300 (tree index:30) done accuracy:0.949791 logloss:0.249004
[INFO 23-08-16 11:05:20.8149 UTC random_forest.cc:802] Training of tree 62/300 (tree index:61) done accuracy:0.949791 logloss:0.247371
[INFO 23-08-16 11:05:20.8155 UTC random_forest.cc:802] Training of tree 73/300 (tree index:73) done accuracy:0.962343 logloss:0.246108
[INFO 23-08-16 11:05:20.8158 UTC random_forest.cc:802] Training of tree 83/300 (tree index:82) done accuracy:0.958159 logloss:0.240771
[INFO 23-08-16 11:05:20.8163 UTC random_forest.cc:802] Training of tree 96/300 (tree index:98) done accuracy:0.962343 logloss:0.0994905
[INFO 23-08-16 11:05:20.8166 UTC random_forest.cc:802] Training of tree 106/300 (tree index:105) done accuracy:0.966527 logloss:0.100095
[INFO 23-08-16 11:05:20.8170 UTC random_forest.cc:802] Training of tree 117/300 (tree index:117) done accuracy:0.962343 logloss:0.0959006
[INFO 23-08-16 11:05:20.8173 UTC random_forest.cc:802] Training of tree 127/300 (tree index:125) done accuracy:0.958159 logloss:0.0962165
[INFO 23-08-16 11:05:20.8177 UTC random_forest.cc:802] Training of tree 138/300 (tree index:137) done accuracy:0.958159 logloss:0.0927663
[INFO 23-08-16 11:05:20.8182 UTC random_forest.cc:802] Training of tree 148/300 (tree index:147) done accuracy:0.966527 logloss:0.0931921
[INFO 23-08-16 11:05:20.8187 UTC random_forest.cc:802] Training of tree 158/300 (tree index:157) done accuracy:0.966527 logloss:0.092117
[INFO 23-08-16 11:05:20.8190 UTC random_forest.cc:802] Training of tree 170/300 (tree index:170) done accuracy:0.966527 logloss:0.0926436
[INFO 23-08-16 11:05:20.8196 UTC random_forest.cc:802] Training of tree 180/300 (tree index:181) done accuracy:0.966527 logloss:0.0927239
[INFO 23-08-16 11:05:20.8200 UTC random_forest.cc:802] Training of tree 190/300 (tree index:187) done accuracy:0.966527 logloss:0.0942833
[INFO 23-08-16 11:05:20.8203 UTC random_forest.cc:802] Training of tree 200/300 (tree index:198) done accuracy:0.966527 logloss:0.0941766
[INFO 23-08-16 11:05:20.8208 UTC random_forest.cc:802] Training of tree 210/300 (tree index:208) done accuracy:0.962343 logloss:0.0938748
[INFO 23-08-16 11:05:20.8211 UTC random_forest.cc:802] Training of tree 220/300 (tree index:219) done accuracy:0.958159 logloss:0.0950461
[INFO 23-08-16 11:05:20.8214 UTC random_forest.cc:802] Training of tree 231/300 (tree index:231) done accuracy:0.953975 logloss:0.0951599
[INFO 23-08-16 11:05:20.8218 UTC random_forest.cc:802] Training of tree 241/300 (tree index:241) done accuracy:0.962343 logloss:0.0948531
[INFO 23-08-16 11:05:20.8221 UTC random_forest.cc:802] Training of tree 251/300 (tree index:250) done accuracy:0.962343 logloss:0.0942377
[INFO 23-08-16 11:05:20.8224 UTC random_forest.cc:802] Training of tree 262/300 (tree index:261) done accuracy:0.962343 logloss:0.0940229
[INFO 23-08-16 11:05:20.8228 UTC random_forest.cc:802] Training of tree 272/300 (tree index:276) done accuracy:0.958159 logloss:0.0934476
[INFO 23-08-16 11:05:20.8231 UTC random_forest.cc:802] Training of tree 282/300 (tree index:281) done accuracy:0.958159 logloss:0.0934649
[INFO 23-08-16 11:05:20.8234 UTC random_forest.cc:802] Training of tree 292/300 (tree index:292) done accuracy:0.958159 logloss:0.0943068
[INFO 23-08-16 11:05:20.8236 UTC random_forest.cc:802] Training of tree 300/300 (tree index:299) done accuracy:0.958159 logloss:0.0945677
[INFO 23-08-16 11:05:20.8250 UTC random_forest.cc:882] Final OOB metrics: accuracy:0.958159 logloss:0.0945677
[INFO 23-08-16 11:05:20.8261 UTC kernel.cc:926] Export model in log directory: /tmpfs/tmp/tmpblfnf8hv with prefix 3b862cbea45f4b2a
[INFO 23-08-16 11:05:20.8303 UTC kernel.cc:944] Save model in resources
[INFO 23-08-16 11:05:20.8335 UTC abstract_model.cc:849] Model self evaluation:
Number of predictions (without weights): 239
Number of predictions (with weights): 239
Task: CLASSIFICATION
Label: __LABEL
Accuracy: 0.958159 CI95[W][0.930062 0.977127]
LogLoss: : 0.0945677
ErrorRate: : 0.041841
Default Accuracy: : 0.422594
Default LogLoss: : 1.04864
Default ErrorRate: : 0.577406
Confusion Table:
truth\prediction
0 1 2 3
0 0 0 0 0
1 0 98 0 3
2 0 1 91 0
3 0 4 2 40
Total: 239
One vs other classes:
[INFO 23-08-16 11:05:20.8441 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpblfnf8hv/model/ with prefix 3b862cbea45f4b2a
[INFO 23-08-16 11:05:20.8582 UTC decision_forest.cc:660] Model loaded with 300 root(s), 4336 node(s), and 7 input feature(s).
[INFO 23-08-16 11:05:20.8582 UTC abstract_model.cc:1311] Engine "RandomForestGeneric" built
[INFO 23-08-16 11:05:20.8582 UTC kernel.cc:1075] Use fast generic engine
Model trained in 0:00:00.059670
Compiling model...
WARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7f823d704ee0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
Model compiled.
<keras.src.callbacks.History at 0x7f83315d9b80>
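The "Model self evaluation" block in the log above can be checked by hand. The sketch below recomputes the reported metrics from the confusion table (rows are the truth, columns the prediction; row and column 0 are empty in this log). The interpretation of the "Default" metrics as "always predict from the label distribution" is an assumption made here; it happens to reproduce the logged numbers.

```python
import math

# Confusion table copied from the training log (rows: truth, columns: prediction).
confusion = [
    [0, 0, 0, 0],
    [0, 98, 0, 3],
    [0, 1, 91, 0],
    [0, 4, 2, 40],
]

total = sum(sum(row) for row in confusion)
correct = sum(confusion[i][i] for i in range(len(confusion)))
accuracy = correct / total
error_rate = 1 - accuracy

# Assumed meaning of the "Default" metrics: ignore the features and use only
# the label distribution (majority class for accuracy, class prior for log loss).
class_counts = [sum(row) for row in confusion]
default_accuracy = max(class_counts) / total
default_logloss = -sum(
    (count / total) * math.log(count / total) for count in class_counts if count > 0
)

print(f"accuracy={accuracy:.6f} error_rate={error_rate:.6f}")
print(f"default_accuracy={default_accuracy:.6f} default_logloss={default_logloss:.5f}")
```

Comparing the model against these defaults shows how much the features actually help: accuracy rises from about 0.42 to about 0.96.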
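4.3 Remarks
- No input features are specified. Therefore, all the columns except the label are used as input features. The features used by the model are shown in the training logs and in model.summary().
- DFs natively consume numerical, categorical, and categorical-set features, as well as missing values. Numerical features do not need to be normalized. Categorical string values do not need to be encoded in a dictionary.
- No training hyperparameters are specified. Therefore, the default hyperparameters are used. Default hyperparameters provide reasonable results in most situations.
- Calling compile on the model before fit is optional. Compile can be used to provide extra evaluation metrics.
- Training algorithms do not need a validation dataset. If a validation dataset is provided, it is only used to show metrics.
- Tweak the verbose argument of RandomForestModel to control the amount of displayed training logs. Set verbose=0 to hide most of the logs. Set verbose=2 to show all the logs.
Note: A categorical-set feature is composed of a set of categorical values (while a categorical feature is only one value). More details and examples are given later.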
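5. Evaluating the model
Let's evaluate our model on the test dataset.
# Compile the model with an extra evaluation metric
model_1.compile(metrics=["accuracy"])
# Evaluate the model on the test dataset
evaluation = model_1.evaluate(test_ds, return_dict=True)
print()
# Print the evaluation results
for name, value in evaluation.items():
  print(f"{name}: {value:.4f}")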
1/1 [==============================] - ETA: 0s - loss: 0.0000e+00 - accuracy: 0.9619
1/1 [==============================] - 0s 295ms/step - loss: 0.0000e+00 - accuracy: 0.9619
loss: 0.0000
accuracy: 0.9619
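Remark: The test accuracy (0.9619) is close to the out-of-bag accuracy (0.9582) shown in the training logs.
See the model self-evaluation section below for more evaluation methods.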
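The out-of-bag (OOB) estimate is possible because the training configuration in the log uses bootstrapping with replacement (`bootstrap_training_dataset: true`, `sampling_with_replacement: true`): each tree is trained on a resample of the data, so some examples are left out of that tree and can serve as its test examples. The sketch below is a simplified simulation of this sampling, not TF-DF's implementation; it estimates the average share of examples left out per tree, using the 239 examples and 300 trees from the log.

```python
import random

def oob_fraction(num_examples, num_trees, seed=0):
    """Average fraction of examples missing from a bootstrap sample."""
    rng = random.Random(seed)
    total_oob = 0
    for _ in range(num_trees):
        # Bootstrap: draw num_examples indices with replacement.
        in_bag = {rng.randrange(num_examples) for _ in range(num_examples)}
        total_oob += num_examples - len(in_bag)
    return total_oob / (num_trees * num_examples)

fraction = oob_fraction(num_examples=239, num_trees=300)
print(f"Average out-of-bag fraction: {fraction:.3f}")
```

The result should be close to 1/e (about 0.37), which is why every example is out-of-bag for roughly a third of the trees and the OOB accuracy behaves much like a held-out evaluation.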
6. Preparing the model for TensorFlow Serving
Export the model to the SavedModel format for later re-use, e.g. with TensorFlow Serving.
# Save the model to the given path
model_1.save("/tmp/my_saved_model")
INFO:tensorflow:Assets written to: /tmp/my_saved_model/assets
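7. Plotting the model
Plotting a decision tree and following the first branches helps in learning about decision forests. In some cases, plotting a model can even be used for debugging.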
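"Following the branches" can also be done in code. The sketch below models a tree as a recursive structure in which each non-leaf node has a condition and two children, the first being the negative child and the second the positive one (the same convention the plot uses). The tree itself is hypothetical: its thresholds and leaf values are made up, not read from the trained model.

```python
# A hypothetical hand-built tree; thresholds and leaf values are made up.
tree = {
    "condition": ("flipper_length_mm", 206.5),  # test: feature >= threshold
    "children": [
        {   # negative child (condition is False)
            "condition": ("bill_length_mm", 43.0),
            "children": [{"value": "Adelie"}, {"value": "Chinstrap"}],
        },
        {"value": "Gentoo"},  # positive child (condition is True)
    ],
}

def predict(node, example):
    """Walk from the root to a leaf, recording each condition evaluated."""
    path = []
    while "condition" in node:
        feature, threshold = node["condition"]
        positive = example[feature] >= threshold
        path.append(f"{feature} >= {threshold}: {positive}")
        node = node["children"][1 if positive else 0]
    return node["value"], path

example = {"flipper_length_mm": 181.0, "bill_length_mm": 39.1}
prediction, path = predict(tree, example)
print(prediction)
for step in path:
    print("  ", step)
```

Reading the recorded path is the textual equivalent of tracing red (negative) and green (positive) edges in the plot.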
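Because of the way they are trained, some models are more interesting to plot than others. Because of the noise injected during training and the depth of the trees, plotting a Random Forest is less informative than plotting a CART or the first tree of a Gradient Boosted Tree.
Nevertheless, let's plot the first tree of our Random Forest model:
# Plot the model with plot_model_in_colab from the model_plotter module
# model_1: the model to plot
# tree_idx: index of the tree to plot (here, the first tree)
# max_depth: maximum depth to draw (here, 3)
tfdf.model_plotter.plot_model_in_colab(model_1, tree_idx=0, max_depth=3)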
[Output: interactive plot of the first tree (tree_idx=0) drawn to depth 3, rendered in the notebook by D3-based JavaScript. Each node shows its condition and value; the first (negative) child is drawn in red and the second (positive) child in green.]
/**
- Adds a tooltip to a dom element.
- @param {!options} options Dictionary of configurations.
- @param {!dom} target Dom element to equip with a tooltip.
- @param {!func} get_content Generates the html content of the tooltip.
*/
function add_tooltip(options, target, get_content) {
function show(d) {
options.tooltip.style(‘display’, ‘block’);
options.tooltip.html(get_content());
}
function hide(d) {
options.tooltip.style(‘display’, ‘none’);
}
function move(d) {
options.tooltip.style(‘display’, ‘block’);
options.tooltip.style(‘left’, (d.pageX + 5) + ‘px’);
options.tooltip.style(‘top’, d.pageY + ‘px’);
}
target.on(‘mouseover’, show);
target.on(‘mouseout’, hide);
target.on(‘mousemove’, move);
}
/**
- Adds a condition inside of a node.
- @param {!options} options Dictionary of configurations.
- @param {!condition} condition Condition to display.
- @param {!output} output Output display accumulator.
*/
function display_condition(options, condition, output) {
threshold_format = d3.format(‘r’);
if (condition.type === ‘IS_MISSING’) {
display_node_text(options, ${condition.attribute} is missing
, output);
return;
}
if (condition.type === ‘IS_TRUE’) {
display_node_text(options, ${condition.attribute} is true
, output);
return;
}
if (condition.type === ‘NUMERICAL_IS_HIGHER_THAN’) {
format = d3.format(‘r’);
display_node_text(
options,
${condition.attribute} >= ${threshold_format(condition.threshold)}
,
output);
return;
}
if (condition.type === ‘CATEGORICAL_IS_IN’) {
display_node_text_with_tooltip(
options, ${condition.attribute} in [...]
,
${condition.attribute} in [${condition.mask}]
, output);
return;
}
if (condition.type === ‘CATEGORICAL_SET_CONTAINS’) {
display_node_text_with_tooltip(
options, ${condition.attribute} intersect [...]
,
${condition.attribute} intersect [${condition.mask}]
, output);
return;
}
if (condition.type === ‘NUMERICAL_SPARSE_OBLIQUE’) {
display_node_text_with_tooltip(
options, Sparse oblique split...
,
[${condition.attributes}]*[${condition.weights}]>=${ threshold_format(condition.threshold)}
,
output);
return;
}
display_node_text(
options, Non supported condition ${condition.type}
, output);
}
/**
-
Adds a value inside of a node.
-
@param {!options} options Dictionary of configurations.
-
@param {!value} value Value to display.
-
@param {!output} output Output display accumulator.
*/
function display_value(options, value, output) {
if (value.type === ‘PROBABILITY’) {
const left_margin = 0;
const right_margin = 50;
const plot_width = options.node_x_size - options.node_padding * 2 -
left_margin - right_margin;let cusum = Array.from(d3.cumsum(value.distribution));
cusum.unshift(0);
const distribution_plot = output.content.append(‘g’).attr(
‘transform’,translate(0,${output.vertical_offset + 0.5})
);distribution_plot.selectAll(‘rect’)
.data(value.distribution)
.join(‘rect’)
.attr(‘height’, 10)
.attr(
‘x’,
(d, i) =>
(cusum[i] * plot_width + left_margin + options.node_padding))
.attr(‘width’, (d, i) => d * plot_width)
.style(‘fill’, (d, i) => d3.schemeSet1[i]);const num_examples =
output.content.append(‘g’)
.attr(‘transform’,translate(0,${output.vertical_offset})
)
.append(‘text’)
.attr(‘x’, options.node_x_size - options.node_padding)
.attr(‘alignment-baseline’, ‘hanging’)
.attr(‘text-anchor’, ‘end’)
.text((${value.num_examples})
);const distribution_details = d3.create(‘ul’);
distribution_details.selectAll(‘li’)
.data(value.distribution)
.join(‘li’)
.append(‘span’)
.text(
(d, i) =>
‘class ’ + i + ‘: ’ + d3.format(’.3%’)(value.distribution[i]));add_tooltip(options, distribution_plot, () => distribution_details.html());
add_tooltip(options, num_examples, () => ‘Number of examples’);output.vertical_offset += 10;
return;
}
if (value.type === ‘REGRESSION’) {
display_node_text(
options,
‘value: ’ + d3.format(‘r’)(value.value) + (
+
d3.format(’.6’)(value.num_examples) + )
,
output);
return;
}
if (value.type === ‘UPLIFT’) {
display_node_text(
options,
‘effect: ’ + d3.format(‘r’)(value.treatment_effect) + (
+
d3.format(’.6’)(value.num_examples) + )
,
output);
return;
}
display_node_text(options, Non supported value ${value.type}
, output);
}
/**
- Adds an explanation inside of a node.
- @param {!options} options Dictionary of configurations.
- @param {!explanation} explanation Explanation to display.
- @param {!output} output Output display accumulator.
*/
function display_explanation(options, explanation, output) {
// Margin before the explanation.
output.vertical_offset += 10;
display_node_text(
options, Non supported explanation ${explanation.type}
, output);
}
/**
- Draw the edges of the tree.
- @param {!options} options Dictionary of configurations.
- @param {!graph} graph D3 search handle containing the graph.
- @param {!tree_struct} tree_struct Structure of the tree (node placement,
-
data, etc.).
*/
function display_edges(options, graph, tree_struct) {
// Draw an edge between a parent and a child node with a bezier.
function draw_single_edge(d) {
return ‘M’ + (d.source.y + options.node_x_size) + ‘,’ + d.source.x + ’ C’ +
(d.source.y + options.node_x_size + options.edge_rounding) + ‘,’ +
d.source.x + ’ ’ + (d.target.y - options.edge_rounding) + ‘,’ +
d.target.x + ’ ’ + d.target.y + ‘,’ + d.target.x;
}
graph.append(‘g’)
.attr(‘fill’, ‘none’)
.attr(‘stroke-width’, 1.2)
.selectAll(‘path’)
.data(tree_struct.links())
.join(‘path’)
.attr(‘d’, draw_single_edge)
.attr(
‘stroke’, d => (d.target === d.source.children[0]) ? ‘#0F0’ : ‘#F00’);
}
display_tree({“margin”: 10, “node_x_size”: 160, “node_y_size”: 28, “node_x_offset”: 180, “node_y_offset”: 33, “font_size”: 10, “edge_rounding”: 20, “node_padding”: 2, “show_plot_bounding_box”: false}, {“value”: {“type”: “PROBABILITY”, “distribution”: [0.4435146443514644, 0.34309623430962344, 0.21338912133891214], “num_examples”: 239.0}, “condition”: {“type”: “NUMERICAL_IS_HIGHER_THAN”, “attribute”: “flipper_length_mm”, “threshold”: 206.5}, “children”: [{“value”: {“type”: “PROBABILITY”, “distribution”: [0.0, 0.9534883720930233, 0.046511627906976744], “num_examples”: 86.0}, “condition”: {“type”: “NUMERICAL_IS_HIGHER_THAN”, “attribute”: “bill_depth_mm”, “threshold”: 17.200000762939453}, “children”: [{“value”: {“type”: “PROBABILITY”, “distribution”: [0.0, 0.2, 0.8], “num_examples”: 5.0}}, {“value”: {“type”: “PROBABILITY”, “distribution”: [0.0, 1.0, 0.0], “num_examples”: 81.0}}]}, {“value”: {“type”: “PROBABILITY”, “distribution”: [0.6928104575163399, 0.0, 0.30718954248366015], “num_examples”: 153.0}, “condition”: {“type”: “CATEGORICAL_IS_IN”, “attribute”: “island”, “mask”: [“Biscoe”, “Torgersen”]}, “children”: [{“value”: {“type”: “PROBABILITY”, “distribution”: [1.0, 0.0, 0.0], “num_examples”: 81.0}}, {“value”: {“type”: “PROBABILITY”, “distribution”: [0.3472222222222222, 0.0, 0.6527777777777778], “num_examples”: 72.0}, “condition”: {“type”: “NUMERICAL_IS_HIGHER_THAN”, “attribute”: “bill_length_mm”, “threshold”: 42.30000305175781}, “children”: [{“value”: {“type”: “PROBABILITY”, “distribution”: [0.0, 0.0, 1.0], “num_examples”: 47.0}}, {“value”: {“type”: “PROBABILITY”, “distribution”: [1.0, 0.0, 0.0], “num_examples”: 25.0}}]}]}]}, “#tree_plot_28efd97aa2df4edca61ad38bf7763da0”)
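The root node on the left contains the first condition (flipper_length_mm >= 206.5 in the tree plotted here), the number of examples (239) and the label distribution (the red-blue-green bar).
Examples that satisfy flipper_length_mm >= 206.5 branch to the green path. The other examples branch to the red path.
The deeper a node is, the more "pure" it tends to become, i.e. its label distribution is concentrated on a smaller subset of classes.
**Note:** Hover over the plot for details.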
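The routing described above can be sketched in a few lines of plain Python. This is only an illustration, not TF-DF code: the nested-dict layout and the helper names `is_satisfied` and `route` are hypothetical, the thresholds and distributions are rounded copies of the plotted tree, missing values are not handled, and the sketch assumes the first child is the branch taken when the condition is satisfied.

```python
# Rounded copy of the plotted tree. Internal nodes carry a condition and two
# children; leaves only carry a label distribution.
# Assumption: children[0] is the branch taken when the condition holds.
TREE = {
    "condition": ("flipper_length_mm", ">=", 206.5),
    "distribution": [0.4435, 0.3431, 0.2134],
    "children": [
        {
            "condition": ("bill_depth_mm", ">=", 17.2),
            "distribution": [0.0, 0.9535, 0.0465],
            "children": [
                {"distribution": [0.0, 0.2, 0.8]},
                {"distribution": [0.0, 1.0, 0.0]},
            ],
        },
        {
            "condition": ("island", "in", ["Biscoe", "Torgersen"]),
            "distribution": [0.6928, 0.0, 0.3072],
            "children": [
                {"distribution": [1.0, 0.0, 0.0]},
                {
                    "condition": ("bill_length_mm", ">=", 42.3),
                    "distribution": [0.3472, 0.0, 0.6528],
                    "children": [
                        {"distribution": [0.0, 0.0, 1.0]},
                        {"distribution": [1.0, 0.0, 0.0]},
                    ],
                },
            ],
        },
    ],
}


def is_satisfied(condition, example):
    """Evaluates a numerical (">=") or categorical ("in") condition."""
    attribute, operator, operand = condition
    value = example[attribute]
    if operator == ">=":
        return value >= operand
    return value in operand


def route(node, example):
    """Returns the label distributions met from the root down to a leaf."""
    path = [node["distribution"]]
    while "children" in node:
        branch = 0 if is_satisfied(node["condition"], example) else 1
        node = node["children"][branch]
        path.append(node["distribution"])
    return path


example = {
    "flipper_length_mm": 190.0,
    "island": "Dream",
    "bill_length_mm": 45.0,
    "bill_depth_mm": 18.0,
}
path = route(TREE, example)
# Share of the majority class at each visited node, from root to leaf.
purity = [max(distribution) for distribution in path]
print(path[-1], purity)
```

For this example the path goes through three conditions before reaching a leaf whose distribution puts all the weight on a single class, while the root is spread over three classes.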
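8. Model structure and feature importance
The overall structure of the model is shown with .summary(). You will see:
- Type: The learning algorithm used to train the model (Random Forest in our case).
- Task: The problem solved by the model (Classification in our case).
- Input Features: The input features of the model.
- Variable Importance: Different measures of the importance of each feature for the model.
- Out-of-bag evaluation: The out-of-bag evaluation of the model. This is a cheap and efficient alternative to cross-validation.
- Number of {trees, nodes} and other metrics: Statistics about the structure of the decision forest.
Remark: The content of the summary depends on the learning algorithm (e.g. out-of-bag evaluation is only available for Random Forest) and on the hyper-parameters (e.g. the mean-decrease-accuracy variable importance can be disabled in the hyper-parameters).
# Set the cell height to 300
%set_cell_height 300
# Print the summary of model_1
model_1.summary()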
Model: "random_forest_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
=================================================================
Total params: 1 (1.00 Byte)
Trainable params: 0 (0.00 Byte)
Non-trainable params: 1 (1.00 Byte)
_________________________________________________________________
Type: "RANDOM_FOREST"
Task: CLASSIFICATION
Label: "__LABEL"
Input Features (7):
bill_depth_mm
bill_length_mm
body_mass_g
flipper_length_mm
island
sex
year
No weights
Variable Importance: INV_MEAN_MIN_DEPTH:
1. "flipper_length_mm" 0.459730 ################
2. "bill_length_mm" 0.428357 #############
3. "bill_depth_mm" 0.318034 #####
4. "island" 0.302253 ####
5. "body_mass_g" 0.270350 ##
6. "sex" 0.239199
7. "year" 0.238541
Variable Importance: NUM_AS_ROOT:
1. "flipper_length_mm" 161.000000 ################
2. "bill_length_mm" 66.000000 ######
3. "bill_depth_mm" 55.000000 ####
4. "body_mass_g" 10.000000
5. "island" 8.000000
Variable Importance: NUM_NODES:
1. "bill_length_mm" 686.000000 ################
2. "bill_depth_mm" 411.000000 #########
3. "flipper_length_mm" 357.000000 ########
4. "body_mass_g" 291.000000 ######
5. "island" 238.000000 #####
6. "sex" 21.000000
7. "year" 14.000000
Variable Importance: SUM_SCORE:
1. "flipper_length_mm" 26375.887035 ################
2. "bill_length_mm" 23387.499002 ##############
3. "bill_depth_mm" 9981.270101 ######
4. "island" 8813.632840 #####
5. "body_mass_g" 3264.050597 #
6. "sex" 101.269852
7. "year" 29.719130
Winner takes all: true
Out-of-bag evaluation: accuracy:0.958159 logloss:0.0945677
Number of trees: 300
Total number of nodes: 4336
Number of nodes by tree:
Count: 300 Average: 14.4533 StdDev: 2.95654
Min: 7 Max: 25 Ignored: 0
----------------------------------------------
[ 7, 8) 6 2.00% 2.00% #
[ 8, 9) 0 0.00% 2.00%
[ 9, 10) 9 3.00% 5.00% #
[ 10, 11) 0 0.00% 5.00%
[ 11, 12) 36 12.00% 17.00% ####
[ 12, 13) 0 0.00% 17.00%
[ 13, 14) 92 30.67% 47.67% ##########
[ 14, 15) 0 0.00% 47.67%
[ 15, 16) 73 24.33% 72.00% ########
[ 16, 17) 0 0.00% 72.00%
[ 17, 18) 48 16.00% 88.00% #####
[ 18, 19) 0 0.00% 88.00%
[ 19, 20) 26 8.67% 96.67% ###
[ 20, 21) 0 0.00% 96.67%
[ 21, 22) 8 2.67% 99.33% #
[ 22, 23) 0 0.00% 99.33%
[ 23, 24) 1 0.33% 99.67%
[ 24, 25) 0 0.00% 99.67%
[ 25, 25] 1 0.33% 100.00%
Depth by leafs:
Count: 2318 Average: 3.27653 StdDev: 1.02213
Min: 1 Max: 7 Ignored: 0
----------------------------------------------
[ 1, 2) 20 0.86% 0.86%
[ 2, 3) 563 24.29% 25.15% #######
[ 3, 4) 787 33.95% 59.10% ##########
[ 4, 5) 704 30.37% 89.47% #########
[ 5, 6) 200 8.63% 98.10% ###
[ 6, 7) 36 1.55% 99.65%
[ 7, 7] 8 0.35% 100.00%
Number of training obs by leaf:
Count: 2318 Average: 30.9318 StdDev: 32.1481
Min: 5 Max: 110 Ignored: 0
----------------------------------------------
[ 5, 10) 1143 49.31% 49.31% ##########
[ 10, 15) 88 3.80% 53.11% #
[ 15, 20) 81 3.49% 56.60% #
[ 20, 26) 78 3.36% 59.97% #
[ 26, 31) 74 3.19% 63.16% #
[ 31, 36) 81 3.49% 66.65% #
[ 36, 42) 103 4.44% 71.10% #
[ 42, 47) 46 1.98% 73.08%
[ 47, 52) 34 1.47% 74.55%
[ 52, 58) 20 0.86% 75.41%
[ 58, 63) 30 1.29% 76.70%
[ 63, 68) 39 1.68% 78.39%
[ 68, 73) 58 2.50% 80.89% #
[ 73, 79) 65 2.80% 83.69% #
[ 79, 84) 98 4.23% 87.92% #
[ 84, 89) 93 4.01% 91.93% #
[ 89, 95) 98 4.23% 96.16% #
[ 95, 100) 57 2.46% 98.62%
[ 100, 105) 25 1.08% 99.70%
[ 105, 110] 7 0.30% 100.00%
Attribute in nodes:
686 : bill_length_mm [NUMERICAL]
411 : bill_depth_mm [NUMERICAL]
357 : flipper_length_mm [NUMERICAL]
291 : body_mass_g [NUMERICAL]
238 : island [CATEGORICAL]
21 : sex [CATEGORICAL]
14 : year [NUMERICAL]
Attribute in nodes with depth <= 0:
161 : flipper_length_mm [NUMERICAL]
66 : bill_length_mm [NUMERICAL]
55 : bill_depth_mm [NUMERICAL]
10 : body_mass_g [NUMERICAL]
8 : island [CATEGORICAL]
Attribute in nodes with depth <= 1:
258 : flipper_length_mm [NUMERICAL]
252 : bill_length_mm [NUMERICAL]
181 : bill_depth_mm [NUMERICAL]
132 : island [CATEGORICAL]
57 : body_mass_g [NUMERICAL]
Attribute in nodes with depth <= 2:
460 : bill_length_mm [NUMERICAL]
318 : bill_depth_mm [NUMERICAL]
317 : flipper_length_mm [NUMERICAL]
207 : island [CATEGORICAL]
172 : body_mass_g [NUMERICAL]
3 : sex [CATEGORICAL]
Attribute in nodes with depth <= 3:
631 : bill_length_mm [NUMERICAL]
390 : bill_depth_mm [NUMERICAL]
341 : flipper_length_mm [NUMERICAL]
265 : body_mass_g [NUMERICAL]
234 : island [CATEGORICAL]
14 : sex [CATEGORICAL]
9 : year [NUMERICAL]
Attribute in nodes with depth <= 5:
683 : bill_length_mm [NUMERICAL]
411 : bill_depth_mm [NUMERICAL]
357 : flipper_length_mm [NUMERICAL]
290 : body_mass_g [NUMERICAL]
238 : island [CATEGORICAL]
21 : sex [CATEGORICAL]
14 : year [NUMERICAL]
Condition type in nodes:
1759 : HigherCondition
259 : ContainsBitmapCondition
Condition type in nodes with depth <= 0:
292 : HigherCondition
8 : ContainsBitmapCondition
Condition type in nodes with depth <= 1:
748 : HigherCondition
132 : ContainsBitmapCondition
Condition type in nodes with depth <= 2:
1267 : HigherCondition
210 : ContainsBitmapCondition
Condition type in nodes with depth <= 3:
1636 : HigherCondition
248 : ContainsBitmapCondition
Condition type in nodes with depth <= 5:
1755 : HigherCondition
259 : ContainsBitmapCondition
Node format: NOT_SET
Training OOB:
trees: 1, Out-of-bag evaluation: accuracy:0.943182 logloss:2.04793
trees: 11, Out-of-bag evaluation: accuracy:0.949367 logloss:0.383614
trees: 21, Out-of-bag evaluation: accuracy:0.953975 logloss:0.386135
trees: 35, Out-of-bag evaluation: accuracy:0.953975 logloss:0.249595
trees: 50, Out-of-bag evaluation: accuracy:0.949791 logloss:0.249004
trees: 62, Out-of-bag evaluation: accuracy:0.949791 logloss:0.247371
trees: 73, Out-of-bag evaluation: accuracy:0.962343 logloss:0.246108
trees: 83, Out-of-bag evaluation: accuracy:0.958159 logloss:0.240771
trees: 96, Out-of-bag evaluation: accuracy:0.962343 logloss:0.0994905
trees: 106, Out-of-bag evaluation: accuracy:0.966527 logloss:0.100095
trees: 117, Out-of-bag evaluation: accuracy:0.962343 logloss:0.0959006
trees: 127, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0962165
trees: 138, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0927663
trees: 148, Out-of-bag evaluation: accuracy:0.966527 logloss:0.0931921
trees: 158, Out-of-bag evaluation: accuracy:0.966527 logloss:0.092117
trees: 170, Out-of-bag evaluation: accuracy:0.966527 logloss:0.0926436
trees: 180, Out-of-bag evaluation: accuracy:0.966527 logloss:0.0927239
trees: 190, Out-of-bag evaluation: accuracy:0.966527 logloss:0.0942833
trees: 200, Out-of-bag evaluation: accuracy:0.966527 logloss:0.0941766
trees: 210, Out-of-bag evaluation: accuracy:0.962343 logloss:0.0938748
trees: 220, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0950461
trees: 231, Out-of-bag evaluation: accuracy:0.953975 logloss:0.0951599
trees: 241, Out-of-bag evaluation: accuracy:0.962343 logloss:0.0948531
trees: 251, Out-of-bag evaluation: accuracy:0.962343 logloss:0.0942377
trees: 262, Out-of-bag evaluation: accuracy:0.962343 logloss:0.0940229
trees: 272, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0934476
trees: 282, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0934649
trees: 292, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0943068
trees: 300, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0945677
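The counts in the summary can be cross-checked with a little arithmetic, which also helps in reading it. The sketch below copies figures from the summary above and relies on one assumption: every non-leaf node has exactly two children, so a tree with k leaves has k - 1 condition nodes and 2k - 1 nodes in total. The variable names are only illustrative.

```python
# Figures copied from the model summary above.
num_trees = 300
total_nodes = 4336
num_leaves = 2318  # "Depth by leafs: Count"
condition_counts = {"HigherCondition": 1759, "ContainsBitmapCondition": 259}
attribute_counts = {
    "bill_length_mm": 686,
    "bill_depth_mm": 411,
    "flipper_length_mm": 357,
    "body_mass_g": 291,
    "island": 238,
    "sex": 21,
    "year": 14,
}

# Nodes that hold a condition (non-leaf nodes).
internal_nodes = total_nodes - num_leaves

checks = {
    # Binary trees: each tree has one fewer condition node than leaves.
    "internal nodes == leaves - trees": internal_nodes == num_leaves - num_trees,
    # Every condition node has exactly one condition type...
    "internal nodes == sum of condition types": internal_nodes
    == sum(condition_counts.values()),
    # ...and, with axis-aligned splits, tests exactly one attribute.
    "internal nodes == sum of attribute uses": internal_nodes
    == sum(attribute_counts.values()),
    "average nodes per tree is about 14.4533": abs(total_nodes / num_trees - 14.4533)
    < 1e-4,
}
for name, ok in checks.items():
    print(f"{name}: {ok}")
```

The 2k - 1 relation is also why the "Number of nodes by tree" histogram only has entries for odd node counts (7, 9, 11, ...).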
All the information in the summary is also available programmatically through the model inspector:
# The input features of the model
model_1.make_inspector().features()
["bill_depth_mm" (1; #1),
"bill_length_mm" (1; #2),
"body_mass_g" (1; #3),
"flipper_length_mm" (1; #4),
"island" (4; #5),
"sex" (4; #6),
"year" (1; #7)]
# The variable importances
model_1.make_inspector().variable_importances()
{'NUM_AS_ROOT': [("flipper_length_mm" (1; #4), 161.0),
("bill_length_mm" (1; #2), 66.0),
("bill_depth_mm" (1; #1), 55.0),
("body_mass_g" (1; #3), 10.0),
("island" (4; #5), 8.0)],
'SUM_SCORE': [("flipper_length_mm" (1; #4), 26375.887034731917),
("bill_length_mm" (1; #2), 23387.499002089724),
("bill_depth_mm" (1; #1), 9981.270100556314),
("island" (4; #5), 8813.63283989951),
("body_mass_g" (1; #3), 3264.0505972094834),
("sex" (4; #6), 101.26985213905573),
("year" (1; #7), 29.719129994511604)],
'NUM_NODES': [("bill_length_mm" (1; #2), 686.0),
("bill_depth_mm" (1; #1), 411.0),
("flipper_length_mm" (1; #4), 357.0),
("body_mass_g" (1; #3), 291.0),
("island" (4; #5), 238.0),
("sex" (4; #6), 21.0),
("year" (1; #7), 14.0)],
'INV_MEAN_MIN_DEPTH': [("flipper_length_mm" (1; #4), 0.4597295587756743),
("bill_length_mm" (1; #2), 0.42835670851367663),
("bill_depth_mm" (1; #1), 0.31803398397339727),
("island" (4; #5), 0.30225257091871593),
("body_mass_g" (1; #3), 0.27035044480247944),
("sex" (4; #6), 0.23919881592559233),
("year" (1; #7), 0.23854067913913543)]}
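The content of the summary and of the inspector depends on the learning algorithm (tfdf.keras.RandomForestModel in this case) and on its hyper-parameters (e.g. compute_oob_variable_importances=True triggers the computation of out-of-bag variable importances for the Random Forest learner).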
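9. Model self-evaluation
During training, TF-DF models can evaluate themselves even if no validation dataset is provided to the fit() method. The exact logic depends on the model. For example, Random Forest uses out-of-bag evaluation, while Gradient Boosted Trees use an internal train-validation split.
**Note:** Although this evaluation is computed during training, it is not computed on the training dataset, so it can serve as a rough, low-quality evaluation.
The model self-evaluation is available through the inspector's evaluation() method.
# Create an inspector for model_1 and get the model's self-evaluation
model_1.make_inspector().evaluation()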
Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.09456771872859744, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)
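Out-of-bag evaluation relies on bootstrapping: each tree is trained on examples drawn with replacement, so some examples are left out of each tree and can be used to evaluate it. The sketch below is a plain-Python illustration of that sampling, not TF-DF code. It assumes each bootstrap sample has as many draws as there are training examples; the function name and the seed are hypothetical. It measures which fraction of 239 examples is out-of-bag for each of 300 trees.

```python
import random


def oob_fraction_per_tree(num_examples, num_trees, seed=0):
    """Returns, for each tree, the fraction of examples left out of its bootstrap."""
    rng = random.Random(seed)
    fractions = []
    for _ in range(num_trees):
        # Bootstrap sample: num_examples draws with replacement.
        in_bag = {rng.randrange(num_examples) for _ in range(num_examples)}
        # Examples never drawn are "out-of-bag" for this tree.
        fractions.append(1 - len(in_bag) / num_examples)
    return fractions


fractions = oob_fraction_per_tree(num_examples=239, num_trees=300)
mean_fraction = sum(fractions) / len(fractions)
print(f"mean out-of-bag fraction: {mean_fraction:.3f}")
```

The expected fraction is roughly 1/e, about 37%. This is consistent with the first entry of the training logs in the next section, where a single tree is evaluated on 88 of the 239 examples. With more trees, nearly every example is out-of-bag for at least one tree, which is why the later entries report 239 examples.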
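10. Plotting the training logs
The training logs show the quality of the model (e.g. accuracy evaluated on the out-of-bag or validation dataset) as a function of the number of trees in the model. These logs help study the trade-off between model size and model quality.
The logs are available in multiple ways:
- Displayed during training if fit() is wrapped in with sys_pipes(): (see the example above).
- At the end of the model summary, i.e. model.summary() (see the example above).
- Programmatically, using the model inspector, i.e. model.make_inspector().training_logs().
- Using TensorBoard.
Let's try the second and third options: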
# Set the cell height to 150
%set_cell_height 150
# Get the training logs of model_1 from its inspector
model_1.make_inspector().training_logs()
[TrainLog(num_trees=1, evaluation=Evaluation(num_examples=88, accuracy=0.9431818181818182, loss=2.04793474890969, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=11, evaluation=Evaluation(num_examples=237, accuracy=0.9493670886075949, loss=0.3836141189693902, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=21, evaluation=Evaluation(num_examples=239, accuracy=0.9539748953974896, loss=0.38613533478027606, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=35, evaluation=Evaluation(num_examples=239, accuracy=0.9539748953974896, loss=0.24959545451602178, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=50, evaluation=Evaluation(num_examples=239, accuracy=0.9497907949790795, loss=0.2490036289936329, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=62, evaluation=Evaluation(num_examples=239, accuracy=0.9497907949790795, loss=0.24737085921058594, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=73, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.24610795769606675, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=83, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.24077113418524235, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=96, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.0994904973703947, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=106, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.1000949550326524, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=117, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.09590058801033258, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=127, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.09621651767593298, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=138, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.09276632447123029, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=148, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.09319210400859432, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=158, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.09211699942041142, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=170, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.09264358151002658, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=180, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.09272387361925516, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=190, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.0942832787314656, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=200, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.09417655552390604, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=210, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.09387483396353083, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=220, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.0950461220674248, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=231, evaluation=Evaluation(num_examples=239, accuracy=0.9539748953974896, loss=0.09515991921548314, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=241, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.09485313651701396, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=251, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.09423767419134473, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=262, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.09402294695439697, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=272, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.09344756691307453, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=282, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.0934649518804009, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=292, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.09430678192307884, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
TrainLog(num_trees=300, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.09456771872859744, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None))]
Let's plot it:
# Import matplotlib for plotting
import matplotlib.pyplot as plt
# Get the training logs of the model
logs = model_1.make_inspector().training_logs()
# Create a 12x4 figure
plt.figure(figsize=(12, 4))
# Left subplot: out-of-bag accuracy as a function of the number of trees
plt.subplot(1, 2, 1)
plt.plot([log.num_trees for log in logs], [log.evaluation.accuracy for log in logs])
plt.xlabel("Number of trees")
plt.ylabel("Accuracy (out-of-bag)")
# Right subplot: out-of-bag logloss as a function of the number of trees
plt.subplot(1, 2, 2)
plt.plot([log.num_trees for log in logs], [log.evaluation.loss for log in logs])
plt.xlabel("Number of trees")
plt.ylabel("Logloss (out-of-bag)")
# Display the figure
plt.show()
This dataset is small. You can see the model converging almost immediately.
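One way to use these logs for the size/quality trade-off is to look for the smallest forest whose out-of-bag logloss is close to the best one observed. The sketch below does this on a subset of the (number of trees, logloss) pairs printed in the summary above. The 5% tolerance and the helper name `smallest_forest_within` are arbitrary choices made for illustration, not a TF-DF feature.

```python
# (num_trees, out-of-bag logloss) pairs: a subset of the training log
# printed in the model summary above.
log = [
    (1, 2.04793),
    (11, 0.383614),
    (35, 0.249595),
    (96, 0.0994905),
    (117, 0.0959006),
    (158, 0.092117),
    (200, 0.0941766),
    (300, 0.0945677),
]


def smallest_forest_within(log, tolerance):
    """Smallest number of trees whose loss is within `tolerance` of the best loss."""
    best_loss = min(loss for _, loss in log)
    for num_trees, loss in sorted(log):
        if loss <= best_loss * (1 + tolerance):
            return num_trees


chosen = smallest_forest_within(log, tolerance=0.05)
print(chosen)
```

On these values, a forest of 117 trees is already within 5% of the best logloss in the subset, which is reached at 158 trees.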
Let's use TensorBoard:
# Load the TensorBoard notebook extension
%load_ext tensorboard
# Google internal version
# %load_ext google3.learning.brain.tensorboard.notebook.extension
# Clear existing results (if any)
!rm -fr "/tmp/tensorboard_logs"
# Export the meta-data to TensorBoard.
model_1.make_inspector().export_to_tensorboard("/tmp/tensorboard_logs")
# Start a TensorBoard instance that reads the logs in "/tmp/tensorboard_logs"
%tensorboard --logdir "/tmp/tensorboard_logs"
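11. Re-train the model with a different learning algorithm
The learning algorithm is defined by the model class. For example, tfdf.keras.RandomForestModel() trains a Random Forest, while tfdf.keras.GradientBoostedTreesModel() trains Gradient Boosted Decision Trees.
The learning algorithms are listed by calling tfdf.keras.get_all_models() or in the learner list.
# List all available models
tfdf.keras.get_all_models()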
[tensorflow_decision_forests.keras.RandomForestModel,
tensorflow_decision_forests.keras.GradientBoostedTreesModel,
tensorflow_decision_forests.keras.CartModel,
tensorflow_decision_forests.keras.DistributedGradientBoostedTreesModel]
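The description of the learning algorithms and of their hyper-parameters is also available in the API reference and in the built-in help:
# help works anywhere.
help(tfdf.keras.RandomForestModel)
# ? only works in IPython or notebooks. It usually opens in a separate panel.
tfdf.keras.RandomForestModel?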
Help on class RandomForestModel in module tensorflow_decision_forests.keras:
class RandomForestModel(tensorflow_decision_forests.keras.wrappers.RandomForestModel)
| RandomForestModel(*args, **kwargs)
|
| Method resolution order:
| RandomForestModel
| tensorflow_decision_forests.keras.wrappers.RandomForestModel
| tensorflow_decision_forests.keras.core.CoreModel
| tensorflow_decision_forests.keras.core_inference.InferenceCoreModel
| keras.src.engine.training.Model
| keras.src.engine.base_layer.Layer
| tensorflow.python.module.module.Module
| tensorflow.python.trackable.autotrackable.AutoTrackable
| tensorflow.python.trackable.base.Trackable
| keras.src.utils.version_utils.LayerVersionSelector
| keras.src.utils.version_utils.ModelVersionSelector
| builtins.object
|
| Methods inherited from tensorflow_decision_forests.keras.wrappers.RandomForestModel:
|
| __init__(self, task: Optional[ForwardRef('abstract_model_pb2.Task')] = 1, features: Optional[List[tensorflow_decision_forests.keras.core.FeatureUsage]] = None, exclude_non_specified_features: Optional[bool] = False, preprocessing: Optional[ForwardRef('tf.keras.models.Functional')] = None, postprocessing: Optional[ForwardRef('tf.keras.models.Functional')] = None, ranking_group: Optional[str] = None, uplift_treatment: Optional[str] = None, temp_directory: Optional[str] = None, verbose: int = 1, hyperparameter_template: Optional[str] = None, advanced_arguments: Optional[tensorflow_decision_forests.keras.core_inference.AdvancedArguments] = None, num_threads: Optional[int] = None, name: Optional[str] = None, max_vocab_count: Optional[int] = 2000, try_resume_training: Optional[bool] = True, check_dataset: Optional[bool] = True, tuner: Optional[tensorflow_decision_forests.component.tuner.tuner.Tuner] = None, discretize_numerical_features: bool = False, num_discretized_numerical_bins: int = 255, multitask: Optional[List[tensorflow_decision_forests.keras.core_inference.MultiTaskItem]] = None, adapt_bootstrap_size_ratio_for_maximum_training_duration: Optional[bool] = False, allow_na_conditions: Optional[bool] = False, bootstrap_size_ratio: Optional[float] = 1.0, bootstrap_training_dataset: Optional[bool] = True, categorical_algorithm: Optional[str] = 'CART', categorical_set_split_greedy_sampling: Optional[float] = 0.1, categorical_set_split_max_num_items: Optional[int] = -1, categorical_set_split_min_item_frequency: Optional[int] = 1, compute_oob_performances: Optional[bool] = True, compute_oob_variable_importances: Optional[bool] = False, growing_strategy: Optional[str] = 'LOCAL', honest: Optional[bool] = False, honest_fixed_separation: Optional[bool] = False, honest_ratio_leaf_examples: Optional[float] = 0.5, in_split_min_examples_check: Optional[bool] = True, keep_non_leaf_label_distribution: Optional[bool] = True, max_depth: Optional[int] = 16, max_num_nodes: 
Optional[int] = None, maximum_model_size_in_memory_in_bytes: Optional[float] = -1.0, maximum_training_duration_seconds: Optional[float] = -1.0, min_examples: Optional[int] = 5, missing_value_policy: Optional[str] = 'GLOBAL_IMPUTATION', num_candidate_attributes: Optional[int] = 0, num_candidate_attributes_ratio: Optional[float] = -1.0, num_oob_variable_importances_permutations: Optional[int] = 1, num_trees: Optional[int] = 300, pure_serving_model: Optional[bool] = False, random_seed: Optional[int] = 123456, sampling_with_replacement: Optional[bool] = True, sorting_strategy: Optional[str] = 'PRESORT', sparse_oblique_normalization: Optional[str] = None, sparse_oblique_num_projections_exponent: Optional[float] = None, sparse_oblique_projection_density_factor: Optional[float] = None, sparse_oblique_weights: Optional[str] = None, split_axis: Optional[str] = 'AXIS_ALIGNED', uplift_min_examples_in_treatment: Optional[int] = 5, uplift_split_score: Optional[str] = 'KULLBACK_LEIBLER', winner_take_all: Optional[bool] = True, explicit_args: Optional[Set[str]] = None)
|
| ----------------------------------------------------------------------
| Static methods inherited from tensorflow_decision_forests.keras.wrappers.RandomForestModel:
|
| capabilities() -> yggdrasil_decision_forests.learner.abstract_learner_pb2.LearnerCapabilities
| Lists the capabilities of the learning algorithm.
|
| predefined_hyperparameters() -> List[tensorflow_decision_forests.keras.core.HyperParameterTemplate]
| Returns a better than default set of hyper-parameters.
|
| They can be used directly with the `hyperparameter_template` argument of the
| model constructor.
|
| These hyper-parameters outperform the default hyper-parameters (either
| generally or in specific scenarios). Like default hyper-parameters, existing
| pre-defined hyper-parameters cannot change.
|
| ----------------------------------------------------------------------
| Methods inherited from tensorflow_decision_forests.keras.core.CoreModel:
|
| collect_data_step(self, data, is_training_example)
| Collect examples e.g. training or validation.
|
| fit(self, x=None, y=None, callbacks=None, verbose: Optional[Any] = None, validation_steps: Optional[int] = None, validation_data: Optional[Any] = None, sample_weight: Optional[Any] = None, steps_per_epoch: Optional[Any] = None, class_weight: Optional[Any] = None, **kwargs) -> keras.src.callbacks.History
| Trains the model.
|
| Local training
| ==============
|
| It is recommended to use a Pandas Dataframe dataset and to convert it to
| a TensorFlow dataset with `pd_dataframe_to_tf_dataset()`:
| ```python
| pd_dataset = pandas.Dataframe(...)
| tf_dataset = pd_dataframe_to_tf_dataset(dataset, label="my_label")
| model.fit(pd_dataset)
| ```
|
| The following dataset formats are supported:
|
| 1. "x" is a `tf.data.Dataset` containing a tuple "(features, labels)".
| "features" can be a dictionary a tensor, a list of tensors or a
| dictionary of tensors (recommended). "labels" is a tensor.
|
| 2. "x" is a tensor, list of tensors or dictionary of tensors containing
| the input features. "y" is a tensor.
|
| 3. "x" is a numpy-array, list of numpy-arrays or dictionary of
| numpy-arrays containing the input features. "y" is a numpy-array.
|
| IMPORTANT: This model trains on the entire dataset at once. This has the
| following consequences:
|
| 1. The dataset needs to be read exactly once. If you use a TensorFlow
| dataset, make sure NOT to add a "repeat" operation.
| 2. The algorithm does not benefit from shuffling the dataset. If you use a
| TensorFlow dataset, make sure NOT to add a "shuffle" operation.
| 3. The dataset needs to be batched (i.e. with a "batch" operation).
| However, the number of elements per batch has no impact on the model.
| Generally, it is recommended to use batches as large as possible as it
| speeds up reading the dataset in TensorFlow.
|
| Input features do not need to be normalized (e.g. dividing numerical values
| by the variance) or indexed (e.g. replacing categorical string values by
| an integer). Additionally, missing values can be consumed natively.
|
| Distributed training
| ====================
|
| Some of the learning algorithms will support distributed training with the
| ParameterServerStrategy.
|
| In this case, the dataset is read asynchronously in between the workers. The
| distribution of the training depends on the learning algorithm.
|
| As with non-distributed training, the dataset should be read exactly once.
| The simplest solution is to divide the dataset into different files (i.e.
| shards) and have each worker read a non-overlapping subset of shards.
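
Editor's sketch (not part of the library documentation): the non-overlapping split described above can be illustrated in plain Python. The helper `shard_paths` and the file names are hypothetical; the slicing mirrors the semantics of `tf.data.Dataset.shard`, which keeps every `num_shards`-th element starting at `index`.

```python
def shard_paths(paths, num_shards, index):
    """Return the files a given worker would read.

    Mirrors the semantics of `tf.data.Dataset.shard`: the worker keeps every
    `num_shards`-th element, starting at position `index`.
    """
    return paths[index::num_shards]


# Hypothetical file names, for illustration only.
paths = [f"train-{i:02d}.csv" for i in range(7)]
num_workers = 3
assignment = [shard_paths(paths, num_workers, w) for w in range(num_workers)]

for worker_idx, files in enumerate(assignment):
    print(worker_idx, files)

# Every file is read by exactly one worker, so the dataset is read once.
assert sorted(f for files in assignment for f in files) == sorted(paths)
```

Because the subsets are disjoint and together cover every file, the dataset as a whole is read exactly once, which is the property the training algorithm relies on.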
|
| IMPORTANT: The training dataset should not be infinite i.e. the training
| dataset should not contain any repeat operation.
|
| Currently (to be changed), the validation dataset (if provided) is simply
| fed to the `model.evaluate()` method. Therefore, it should satisfy Keras'
| evaluate API. Notably, for distributed training, the validation dataset
| should be infinite (i.e. have a repeat operation).
|
| See https://www.tensorflow.org/decision_forests/distributed_training for
| more details and examples.
|
| Here is a single example of distributed training using PSS for both dataset
| reading and training distribution.
|
| ```python
| # NOTE: `label_table`, `batch_size` and `num_test_examples` are assumed to
| # be defined elsewhere.
| def dataset_fn(context, paths, training=True):
|     ds_path = tf.data.Dataset.from_tensor_slices(paths)
|
|     if context is not None:
|         # Train on at least 2 workers.
|         current_worker = tfdf.keras.get_worker_idx_and_num_workers(context)
|         assert current_worker.num_workers >= 2
|
|         # Split the dataset's examples among the workers.
|         ds_path = ds_path.shard(
|             num_shards=current_worker.num_workers,
|             index=current_worker.worker_idx)
|
|     def read_csv_file(path):
|         numerical = tf.constant([math.nan], dtype=tf.float32)
|         categorical_string = tf.constant([""], dtype=tf.string)
|         csv_columns = [
|             numerical,  # age
|             categorical_string,  # workclass
|             numerical,  # fnlwgt
|             ...
|         ]
|         return tf.data.experimental.CsvDataset(path, csv_columns, header=True)
|
|     # Defined in the enclosing scope so that `map_features` can use them.
|     column_names = [
|         "age", "workclass", "fnlwgt", ...
|     ]
|     label_name = "label"
|
|     ds_columns = ds_path.interleave(read_csv_file)
|
|     def map_features(*columns):
|         assert len(column_names) == len(columns)
|         features = {column_names[i]: col for i, col in enumerate(columns)}
|         label = label_table.lookup(features.pop(label_name))
|         return features, label
|
|     ds_dataset = ds_columns.map(map_features)
|     if not training:
|         # The evaluation dataset should be infinite for distributed training.
|         ds_dataset = ds_dataset.repeat(None)
|     ds_dataset = ds_dataset.batch(batch_size)
|     return ds_dataset
|
| strategy = tf.distribute.experimental.ParameterServerStrategy(...)
| sharded_train_paths = [...]  # List of training dataset files.
| sharded_test_paths = [...]  # List of test dataset files.
| with strategy.scope():
|     model = DistributedGradientBoostedTreesModel()
|     train_dataset = strategy.distribute_datasets_from_function(
|         lambda context: dataset_fn(context, sharded_train_paths))
|
|     test_dataset = strategy.distribute_datasets_from_function(
|         lambda context: dataset_fn(
|             context, sharded_test_paths, training=False))
|
| model.fit(train_dataset)
| evaluation = model.evaluate(
|     test_dataset, steps=num_test_examples // batch_size)
| ```
|
| Args:
| x: Training dataset (See details above for the supported formats).
| y: Label of the training dataset. Only used if "x" does not contain the
| labels.
| callbacks: Callbacks triggered during the training. The training runs in a
| single epoch, itself run in a single step. Therefore, callback logic can
| be called equivalently before/after the fit function.
| verbose: Verbosity mode. 0 = silent, 1 = small details, 2 = full details.
| validation_steps: Number of steps in the evaluation dataset when
| evaluating the trained model with `model.evaluate()`. If not specified,
| evaluates the model on the entire dataset (generally recommended; not
| yet supported for distributed datasets).
| validation_data: Validation dataset. If specified, the learner might use
| this dataset to help training e.g. early stopping.
| sample_weight: Training weights. Note: training weights can also be
| provided as the third output in a `tf.data.Dataset` e.g. (features,
| label, weights).
| steps_per_epoch: [Parameter will be removed] Number of training batches to
| load before training the model. Currently, only supported for
| distributed training.
| class_weight: For binary classification only. Mapping class indices
| (integers) to a weight (float) value. Only available for non-Distributed
| training. For maximum compatibility, feed example weights through the
| tf.data.Dataset or using the `weight` argument of
| `pd_dataframe_to_tf_dataset`.
| **kwargs: Extra arguments passed to the core keras model's fit. Note that
| not all keras' model fit arguments are supported.
|
| Returns:
| A `History` object. Its `History.history` attribute is not yet
| implemented for decision forests algorithms, and will return empty.
| All other fields are filled as usual for `keras.Model.fit()`.
|
| fit_on_dataset_path(self, train_path: str, label_key: Optional[str] = None, weight_key: Optional[str] = None, valid_path: Optional[str] = None, dataset_format: Optional[str] = 'csv', max_num_scanned_rows_to_accumulate_statistics: Optional[int] = 100000, try_resume_training: Optional[bool] = True, input_model_signature_fn: Optional[Callable[[tensorflow_decision_forests.component.inspector.inspector.AbstractInspector], Any]] = <function build_default_input_model_signature at 0x7f823d6db040>, num_io_threads: int = 10)
| Trains the model on a dataset stored on disk.
|
| This solution is generally more efficient and easier than loading the
| dataset with a `tf.data.Dataset` both for local and distributed training.
|
| Usage example:
|
| # Local training
| ```python
| model = keras.GradientBoostedTreesModel()
| model.fit_on_dataset_path(
|     train_path="/path/to/dataset.csv",
|     label_key="label",
|     dataset_format="csv")
| model.save("/model/path")
| ```
|
| # Distributed training
| ```python
| with tf.distribute.experimental.ParameterServerStrategy(...).scope():
|     model = keras.DistributedGradientBoostedTreesModel()
| model.fit_on_dataset_path(
|     train_path="/path/to/dataset@10",
|     label_key="label",
|     dataset_format="tfrecord+tfe")
| model.save("/model/path")
| ```
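
Editor's sketch (not part of the library documentation): the `@10` suffix in the distributed example is shard notation. The helper below, `expand_sharded_path`, is hypothetical and assumes the common `base-XXXXX-of-YYYYY` file naming for sharded datasets; consult the Yggdrasil Decision Forests user manual for the authoritative rules.

```python
def expand_sharded_path(path):
    """Expand "base@N" into N shard file names; other paths are returned as-is.

    Assumption: shards are named "base-00000-of-0000N", ..., which is the
    usual convention for sharded datasets.
    """
    base, sep, count = path.rpartition("@")
    if not sep or not count.isdigit():
        return [path]
    num_shards = int(count)
    return [f"{base}-{i:05d}-of-{num_shards:05d}" for i in range(num_shards)]


shards = expand_sharded_path("/path/to/dataset@10")
print(len(shards))
print(shards[0])
print(shards[-1])
```

Under that assumption, `/path/to/dataset@10` stands for ten files, from `/path/to/dataset-00000-of-00010` to `/path/to/dataset-00009-of-00010`.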
|
| Args:
| train_path: Path to the training dataset. Supports comma separated files,
| shard and glob notation.
| label_key: Name of the label column.
| weight_key: Name of the weighting column.
| valid_path: Path to the validation dataset. If not provided, or if the
| learning algorithm does not support/need a validation dataset,
| `valid_path` is ignored.
| dataset_format: Format of the dataset. Should be one of the registered
| dataset formats (see [User
| Manual](https://github.com/google/yggdrasil-decision-forests/blob/main/documentation/user_manual.md#dataset-path-and-format)
| for more details). The format "csv" is always available but it is
| generally only suited for small datasets.
| max_num_scanned_rows_to_accumulate_statistics: Maximum number of examples
| to scan to determine the statistics of the features (i.e. the dataspec,
| e.g. mean value, dictionaries). (Currently) the "first" examples of the
| dataset are scanned (e.g. the first examples of the file if the dataset is
| a single file). Therefore, it is important that the scanned examples are
| relatively uniformly sampled; notably, they should contain all the
| possible categorical values (otherwise the unseen values will be
| treated as out-of-vocabulary). If set to None, the entire dataset is
| scanned. This parameter has no effect if the dataset is stored in a
| format that already contains those values.
| try_resume_training: If true, tries to resume training from the model
| checkpoint stored in the `temp_directory` directory. If `temp_directory`
| does not contain any model checkpoint, start the training from the
| start. Works in the following three situations: (1) The training was
| interrupted by the user (e.g. ctrl+c). (2) the training job was
| interrupted (e.g. rescheduling), and (3) the hyper-parameters of the
| model were changed such that an initially completed training is now
| incomplete (e.g. increasing the number of trees).
| input_model_signature_fn: A lambda that returns the
| (Dense,Sparse,Ragged)TensorSpec (or structure of TensorSpec e.g.
| dictionary, list) corresponding to input signature of the model. If not
| specified, the input model signature is created by
| `build_default_input_model_signature`. For example, specify
| `input_model_signature_fn` if a numerical input feature (which is
| consumed as DenseTensorSpec(float32) by default) will be fed
| differently (e.g. RaggedTensor(int64)).
| num_io_threads: Number of threads to use for IO operations e.g. reading a
| dataset from disk. Increasing this value can speed up IO operations when
| IO operations are either latency-bound or CPU-bound.
|
| Returns:
| A `History` object. Its `History.history` attribute is not yet
| implemented for decision forests algorithms, and will return empty.
| All other fields are filled as usual for `keras.Model.fit()`.
|
| load_weights(self, *args, **kwargs)
| No-op for TensorFlow Decision Forests models.
|
| `load_weights` is not supported by TensorFlow Decision Forests models.
| To save and restore a model, use the SavedModel API i.e.
| `model.save(...)` and `tf.keras.models.load_model(...)`. To resume the
| training of an existing model, create the model with
| `try_resume_training=True` (default value) and with a similar
| `temp_directory` argument. See documentation of `try_resume_training`
| for more details.
|
| Args:
| *args: Passed through to base `keras.Model` implementation.
| **kwargs: Passed through to base `keras.Model` implementation.
|
| save(self, filepath: str, overwrite: Optional[bool] = True, **kwargs)
| Saves the model as a TensorFlow SavedModel.
|
| The exported SavedModel contains a standalone Yggdrasil Decision Forests
| model in the "assets" sub-directory. The Yggdrasil model can be used
| directly using the Yggdrasil API. However, this model does not contain the
| "preprocessing" layer (if any).
|
| Args:
| filepath: Path to the output model.
| overwrite: If true, overwrite an already existing model. If false, raise an
| error if a model already exists.
| **kwargs: Arguments passed to the core keras model's save.
|
| support_distributed_training(self)
|
| train_on_batch(self, *args, **kwargs)
| Not supported for TensorFlow Decision Forests models.
|
| Decision forests are not trained in batches the same way neural networks
| are. To avoid confusion, train_on_batch is disabled.
|
| Args:
| *args: Ignored.
| **kwargs: Ignored.
|
| train_step(self, data)
| Collects training examples.
|
| valid_step(self, data)
| Collects validation examples.
|
| ----------------------------------------------------------------------
| Readonly properties inherited from tensorflow_decision_forests.keras.core.CoreModel:
|
| exclude_non_specified_features
| If true, only use the features specified in "features".
|
| learner
| Name of the learning algorithm used to train the model.
|
| learner_params
| Gets the dictionary of hyper-parameters passed in the model constructor.
|
| Changing this dictionary will impact the training.
|
| num_threads
| Number of threads used to train the model.
|
| num_training_examples
| Number of training examples.
|
| num_validation_examples
| Number of validation examples.
|
| training_model_id
| Identifier of the model.
|
| ----------------------------------------------------------------------
| Methods inherited from tensorflow_decision_forests.keras.core_inference.InferenceCoreModel:
|
| call(self, inputs, training=False)
| Inference of the model.
|
| This method is used for prediction and evaluation of a trained model.
|
| Args:
| inputs: Input tensors.
| training: Is the model being trained. Always False.
|
| Returns:
| Model predictions.
|
| call_get_leaves(self, inputs)
| Computes the index of the active leaf in each tree.
|
| The active leaf is the leaf that receives the example during inference.
|
| The returned value "leaves[i,j]" is the index of the active leaf for the
| i-th example and the j-th tree. Leaves are indexed by depth first
| exploration with the negative child visited before the positive one
| (similarly as "iterate_on_nodes()" iteration). Leaf indices are also
| available with LeafNode.leaf_idx.
|
| Args:
| inputs: Input tensors. Same signature as the model's "call(inputs)".
|
| Returns:
| Index of the active leaf for each tree in the model.
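
Editor's sketch (not part of the library documentation): the leaf indexing rule above (depth-first exploration, negative child before positive child) can be illustrated on a toy tree. The `Node` class and `index_leaves` helper are hypothetical and do not use the TF-DF inspector API.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """Hypothetical minimal tree node; a node without children is a leaf."""
    name: str
    neg: Optional["Node"] = None
    pos: Optional["Node"] = None


def index_leaves(root):
    """Map leaf name -> index, depth first, negative child before positive."""
    indices = {}

    def visit(node):
        if node.neg is None and node.pos is None:
            # Leaves are numbered in the order they are first reached.
            indices[node.name] = len(indices)
            return
        visit(node.neg)
        visit(node.pos)

    visit(root)
    return indices


# root -> (negative: n -> (A, B), positive: C)
tree = Node("root",
            neg=Node("n", neg=Node("A"), pos=Node("B")),
            pos=Node("C"))
leaf_indices = index_leaves(tree)
print(leaf_indices)
```

In this toy tree the leaves under the negative branch receive the lowest indices, so `A`, `B` and `C` are numbered 0, 1 and 2.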
|
| compile(self, metrics=None, weighted_metrics=None, **kwargs)
| Configure the model for training.
|
| Unlike for most Keras models, calling "compile" is optional before calling
| "fit".
|
| Args:
| metrics: List of metrics to be evaluated by the model during training and
| testing.
| weighted_metrics: List of metrics to be evaluated and weighted by
| `sample_weight` or `class_weight` during training and testing.
| **kwargs: Other arguments passed to compile.
|
| Raises:
| ValueError: Invalid arguments.
|
| make_inspector(self, index: int = 0) -> tensorflow_decision_forests.component.inspector.inspector.AbstractInspector
| Creates an inspector to access the internal model structure.
|
| Usage example:
|
| ```python
| inspector = model.make_inspector()
| print(inspector.num_trees())
| print(inspector.variable_importances())
| ```
|
| Args:
| index: Index of the sub-model. Only used for multitask models.
|
| Returns:
| A model inspector.
|
| make_predict_function(self)
| Prediction of the model (!= evaluation).
|
| make_test_function(self)
| Predictions for evaluation.
|
| predict_get_leaves(self, x)
| Gets the index of the active leaf of each tree.
|
| The active leaf is the leaf that receives the example during inference.
|
| The returned value "leaves[i,j]" is the index of the active leaf for the
| i-th example and the j-th tree. Leaves are indexed by depth first
| exploration with the negative child visited before the positive one
| (similarly as "iterate_on_nodes()" iteration). Leaf indices are also
| available with LeafNode.leaf_idx.
|
| Args:
| x: Input samples as a tf.data.Dataset.
|
| Returns:
| Index of the active leaf for each tree in the model.
|
| ranking_group(self) -> Optional[str]
|
| summary(self, line_length=None, positions=None, print_fn=None)
| Shows information about the model.
|
| uplift_treatment(self) -> Optional[str]
|
| yggdrasil_model_path_tensor(self, multitask_model_index: int = 0) -> Optional[tensorflow.python.framework.ops.Tensor]
| Gets the path to yggdrasil model, if available.
|
| The effective path can be obtained with:
|
| ```python
| yggdrasil_model_path_tensor().numpy().decode("utf-8")
| ```
|
| Args:
| multitask_model_index: Index of the sub-model. Only used for multitask
| models.
|
| Returns:
| Path to the Yggdrasil model.
|
| yggdrasil_model_prefix(self, index: int = 0) -> str
| Gets the prefix of the internal yggdrasil model.
|
| ----------------------------------------------------------------------
| Readonly properties inherited from tensorflow_decision_forests.keras.core_inference.InferenceCoreModel:
|
| multitask
| Tasks to solve.
|
| task
| Task to solve (e.g. CLASSIFICATION, REGRESSION, RANKING).
|
| ----------------------------------------------------------------------
| Methods inherited from keras.src.engine.training.Model:
|
| __call__(self, *args, **kwargs)
|
| __copy__(self)
|
| __deepcopy__(self, memo)
|
| __reduce__(self)
| Helper for pickle.
|
| __setattr__(self, name, value)
| Support self.foo = trackable syntax.
|
| build(self, input_shape)
| Builds the model based on input shapes received.
|
| This is to be used for subclassed models, which do not know at
| instantiation time what their inputs look like.
|
| This method only exists for users who want to call `model.build()` in a
| standalone way (as a substitute for calling the model on real data to
| build it). It will never be called by the framework (and thus it will
| never throw unexpected errors in an unrelated workflow).
|
| Args:
| input_shape: Single tuple, `TensorShape` instance, or list/dict of
| shapes, where shapes are tuples, integers, or `TensorShape`
| instances.
|
| Raises:
| ValueError:
| 1. In case of invalid user-provided data (not of type tuple,
| list, `TensorShape`, or dict).
| 2. If the model requires call arguments that are agnostic
| to the input shapes (positional or keyword arg in call
| signature).
| 3. If not all layers were properly built.
| 4. If float type inputs are not supported within the layers.
|
| In each of these cases, the user should build their model by calling
| it on real tensor data.
|
| compile_from_config(self, config)
| Compiles the model with the information given in config.
|
| This method uses the information in the config (optimizer, loss,
| metrics, etc.) to compile the model.
|
| Args:
| config: Dict containing information for compiling the model.
|
| compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None)
| Compute the total loss, validate it, and return it.
|
| Subclasses can optionally override this method to provide custom loss
| computation logic.
|
| Example:
| ```python
| class MyModel(tf.keras.Model):
|
|     def __init__(self, *args, **kwargs):
|         super(MyModel, self).__init__(*args, **kwargs)
|         self.loss_tracker = tf.keras.metrics.Mean(name='loss')
|
|     def compute_loss(self, x, y, y_pred, sample_weight):
|         loss = tf.reduce_mean(tf.math.squared_difference(y_pred, y))
|         loss += tf.add_n(self.losses)
|         self.loss_tracker.update_state(loss)
|         return loss
|
|     def reset_metrics(self):
|         self.loss_tracker.reset_states()
|
|     @property
|     def metrics(self):
|         return [self.loss_tracker]
|
| tensors = tf.random.uniform((10, 10)), tf.random.uniform((10,))
| dataset = tf.data.Dataset.from_tensor_slices(tensors).repeat().batch(1)
|
| inputs = tf.keras.layers.Input(shape=(10,), name='my_input')
| outputs = tf.keras.layers.Dense(10)(inputs)
| model = MyModel(inputs, outputs)
| model.add_loss(tf.reduce_sum(outputs))
|
| optimizer = tf.keras.optimizers.SGD()
| model.compile(optimizer, loss='mse', steps_per_execution=10)
| model.fit(dataset, epochs=2, steps_per_epoch=10)
| print('My custom loss: ', model.loss_tracker.result().numpy())
| ```
|
| Args:
| x: Input data.
| y: Target data.
| y_pred: Predictions returned by the model (output of `model(x)`)
| sample_weight: Sample weights for weighting the loss function.
|
| Returns:
| The total loss as a `tf.Tensor`, or `None` if no loss results (which
| is the case when called by `Model.test_step`).
|
| compute_metrics(self, x, y, y_pred, sample_weight)
| Update metric states and collect all metrics to be returned.
|
| Subclasses can optionally override this method to provide custom metric
| updating and collection logic.
|
| Example:
| ```python
| class MyModel(tf.keras.Sequential):
|
|     def compute_metrics(self, x, y, y_pred, sample_weight):
|
|         # This super call updates `self.compiled_metrics` and returns
|         # results for all metrics listed in `self.metrics`.
|         metric_results = super(MyModel, self).compute_metrics(
|             x, y, y_pred, sample_weight)
|
|         # Note that `self.custom_metric` is not listed in `self.metrics`.
|         self.custom_metric.update_state(x, y, y_pred, sample_weight)
|         metric_results['custom_metric_name'] = self.custom_metric.result()
|         return metric_results
| ```
|
| Args:
| x: Input data.
| y: Target data.
| y_pred: Predictions returned by the model (output of `model.call(x)`)
| sample_weight: Sample weights for weighting the loss function.
|
| Returns:
| A `dict` containing values that will be passed to
| `tf.keras.callbacks.CallbackList.on_train_batch_end()`. Typically, the
| values of the metrics listed in `self.metrics` are returned. Example:
| `{'loss': 0.2, 'accuracy': 0.7}`.
|
| evaluate(self, x=None, y=None, batch_size=None, verbose='auto', sample_weight=None, steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False, return_dict=False, **kwargs)
| Returns the loss value & metrics values for the model in test mode.
|
| Computation is done in batches (see the `batch_size` arg.)
|
| Args:
| x: Input data. It could be:
| - A Numpy array (or array-like), or a list of arrays
| (in case the model has multiple inputs).
| - A TensorFlow tensor, or a list of tensors
| (in case the model has multiple inputs).
| - A dict mapping input names to the corresponding array/tensors,
| if the model has named inputs.
| - A `tf.data` dataset. Should return a tuple
| of either `(inputs, targets)` or
| `(inputs, targets, sample_weights)`.
| - A generator or `keras.utils.Sequence` returning `(inputs,
| targets)` or `(inputs, targets, sample_weights)`.
| A more detailed description of unpacking behavior for iterator
| types (Dataset, generator, Sequence) is given in the `Unpacking
| behavior for iterator-like inputs` section of `Model.fit`.
| y: Target data. Like the input data `x`, it could be either Numpy
| array(s) or TensorFlow tensor(s). It should be consistent with `x`
| (you cannot have Numpy inputs and tensor targets, or inversely).
| If `x` is a dataset, generator or `keras.utils.Sequence` instance,
| `y` should not be specified (since targets will be obtained from
| the iterator/dataset).
| batch_size: Integer or `None`. Number of samples per batch of
| computation. If unspecified, `batch_size` will default to 32. Do
| not specify the `batch_size` if your data is in the form of a
| dataset, generators, or `keras.utils.Sequence` instances (since
| they generate batches).
| verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
| 0 = silent, 1 = progress bar, 2 = single line.
| `"auto"` becomes 1 for most cases, and 2 when used with
| `ParameterServerStrategy`. Note that the progress bar is not
| particularly useful when logged to a file, so `verbose=2` is
| recommended when not running interactively (e.g. in a production
| environment). Defaults to 'auto'.
| sample_weight: Optional Numpy array of weights for the test samples,
| used for weighting the loss function. You can either pass a flat
| (1D) Numpy array with the same length as the input samples
| (1:1 mapping between weights and samples), or in the case of
| temporal data, you can pass a 2D array with shape `(samples,
| sequence_length)`, to apply a different weight to every
| timestep of every sample. This argument is not supported when
| `x` is a dataset, instead pass sample weights as the third
| element of `x`.
| steps: Integer or `None`. Total number of steps (batches of samples)
| before declaring the evaluation round finished. Ignored with the
| default value of `None`. If x is a `tf.data` dataset and `steps`
| is None, 'evaluate' will run until the dataset is exhausted. This
| argument is not supported with array inputs.
| callbacks: List of `keras.callbacks.Callback` instances. List of
| callbacks to apply during evaluation. See
| [callbacks](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks).
| max_queue_size: Integer. Used for generator or
| `keras.utils.Sequence` input only. Maximum size for the generator
| queue. If unspecified, `max_queue_size` will default to 10.
| workers: Integer. Used for generator or `keras.utils.Sequence` input
| only. Maximum number of processes to spin up when using
| process-based threading. If unspecified, `workers` will default to
| 1.
| use_multiprocessing: Boolean. Used for generator or
| `keras.utils.Sequence` input only. If `True`, use process-based
| threading. If unspecified, `use_multiprocessing` will default to
| `False`. Note that because this implementation relies on
| multiprocessing, you should not pass non-picklable arguments to
| the generator as they can't be passed easily to children
| processes.
| return_dict: If `True`, loss and metric results are returned as a
| dict, with each key being the name of the metric. If `False`, they
| are returned as a list.
| **kwargs: Unused at this time.
|
| See the discussion of `Unpacking behavior for iterator-like inputs` for
| `Model.fit`.
|
| Returns:
| Scalar test loss (if the model has a single output and no metrics)
| or list of scalars (if the model has multiple outputs
| and/or metrics). The attribute `model.metrics_names` will give you
| the display labels for the scalar outputs.
|
| Raises:
| RuntimeError: If `model.evaluate` is wrapped in a `tf.function`.
|
| evaluate_generator(self, generator, steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)
| Evaluates the model on a data generator.
|
| DEPRECATED:
| `Model.evaluate` now supports generators, so there is no longer any
| need to use this endpoint.
|
| export(self, filepath)
| Create a SavedModel artifact for inference (e.g. via TF-Serving).
|
| This method lets you export a model to a lightweight SavedModel artifact
| that contains the model's forward pass only (its `call()` method)
| and can be served via e.g. TF-Serving. The forward pass is registered
| under the name `serve()` (see example below).
|
| The original code of the model (including any custom layers you may
| have used) is *no longer* necessary to reload the artifact -- it is
| entirely standalone.
|
| Args:
| filepath: `str` or `pathlib.Path` object. Path where to save
| the artifact.
|
| Example:
|
| ```python
| # Create the artifact
| model.export("path/to/location")
|
| # Later, in a different process / environment...
| reloaded_artifact = tf.saved_model.load("path/to/location")
| predictions = reloaded_artifact.serve(input_data)
| ```
|
| If you would like to customize your serving endpoints, you can
| use the lower-level `keras.export.ExportArchive` class. The `export()`
| method relies on `ExportArchive` internally.
|
| fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, validation_freq=1, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
| Fits the model on data yielded batch-by-batch by a Python generator.
|
| DEPRECATED:
| `Model.fit` now supports generators, so there is no longer any need to
| use this endpoint.
|
| get_compile_config(self)
| Returns a serialized config with information for compiling the model.
|
| This method returns a config dictionary containing all the information
| (optimizer, loss, metrics, etc.) with which the model was compiled.
|
| Returns:
| A dict containing information for compiling the model.
|
| get_config(self)
| Returns the config of the `Model`.
|
| Config is a Python dictionary (serializable) containing the
| configuration of an object, which in this case is a `Model`. This allows
| the `Model` to be reinstantiated later (without its trained weights)
| from this configuration.
|
| Note that `get_config()` does not guarantee to return a fresh copy of
| dict every time it is called. The callers should make a copy of the
| returned dict if they want to modify it.
|
| Developers of subclassed `Model` are advised to override this method,
| and continue to update the dict from `super(MyModel, self).get_config()`
| to provide the proper configuration of this `Model`. The default config
| will return config dict for init parameters if they are basic types.
| Raises `NotImplementedError` in cases where a custom
| `get_config()` implementation is required for the subclassed model.
|
| Returns:
| Python dictionary containing the configuration of this `Model`.
|
| get_layer(self, name=None, index=None)
| Retrieves a layer based on either its name (unique) or index.
|
| If `name` and `index` are both provided, `index` will take precedence.
| Indices are based on order of horizontal graph traversal (bottom-up).
|
| Args:
| name: String, name of layer.
| index: Integer, index of layer.
|
| Returns:
| A layer instance.
|
| get_metrics_result(self)
| Returns the model's metrics values as a dict.
|
| If any of the metric results is a dict (containing multiple metrics),
| each of them gets added to the top level returned dict of this method.
|
| Returns:
| A `dict` containing values of the metrics listed in `self.metrics`.
| Example:
| `{'loss': 0.2, 'accuracy': 0.7}`.
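
Editor's sketch (not part of the library documentation): the flattening behaviour described for `get_metrics_result` can be illustrated in plain Python. The function `flatten_metric_results` and the metric names are hypothetical; this only mimics the described merge of nested dicts into the top-level dict, not the actual Keras code.

```python
def flatten_metric_results(results):
    """Merge dict-valued metric results into a single top-level dict."""
    flat = {}
    for name, value in results.items():
        if isinstance(value, dict):
            # A metric returning several values contributes each of them
            # directly to the top level.
            flat.update(value)
        else:
            flat[name] = value
    return flat


# Hypothetical metric results, for illustration only.
raw = {"loss": 0.2, "multi": {"accuracy": 0.7, "auc": 0.9}}
flat_results = flatten_metric_results(raw)
print(flat_results)
```

Note that, in this sketch, the name of the dict-valued metric itself (`multi`) does not appear in the flattened result; only its inner keys do.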
|
| get_weight_paths(self)
| Retrieve all the variables and their paths for the model.
|
| The variable path (string) is a stable key to identify a `tf.Variable`
| instance owned by the model. It can be used to specify variable-specific
| configurations (e.g. DTensor, quantization) from a global view.
|
| This method returns a dict with weight object paths as keys
| and the corresponding `tf.Variable` instances as values.
|
| Note that if the model is a subclassed model and the weights haven't
| been initialized, an empty dict will be returned.
|
| Returns:
| A dict where keys are variable paths and values are `tf.Variable`
| instances.
|
| Example:
|
| ```python
| class SubclassModel(tf.keras.Model):
|
|     def __init__(self, name=None):
|         super().__init__(name=name)
|         self.d1 = tf.keras.layers.Dense(10)
|         self.d2 = tf.keras.layers.Dense(20)
|
|     def call(self, inputs):
|         x = self.d1(inputs)
|         return self.d2(x)
|
| model = SubclassModel()
| model(tf.zeros((10, 10)))
| weight_paths = model.get_weight_paths()
| # weight_paths:
| # {
| #    'd1.kernel': model.d1.kernel,
| #    'd1.bias': model.d1.bias,
| #    'd2.kernel': model.d2.kernel,
| #    'd2.bias': model.d2.bias,
| # }
|
| # Functional model
| inputs = tf.keras.Input((10,), batch_size=10)
| x = tf.keras.layers.Dense(20, name='d1')(inputs)
| output = tf.keras.layers.Dense(30, name='d2')(x)
| model = tf.keras.Model(inputs, output)
| d1 = model.layers[1]
| d2 = model.layers[2]
| weight_paths = model.get_weight_paths()
| # weight_paths:
| # {
| #    'd1.kernel': d1.kernel,
| #    'd1.bias': d1.bias,
| #    'd2.kernel': d2.kernel,
| #    'd2.bias': d2.bias,
| # }
| ```
|
| get_weights(self)
| Retrieves the weights of the model.
|
| Returns:
| A flat list of Numpy arrays.
|
| make_train_function(self, force=False)
| Creates a function that executes one step of training.
|
| This method can be overridden to support custom training logic.
| This method is called by `Model.fit` and `Model.train_on_batch`.
|
| Typically, this method directly controls `tf.function` and
| `tf.distribute.Strategy` settings, and delegates the actual training
| logic to `Model.train_step`.
|
| This function is cached the first time `Model.fit` or
| `Model.train_on_batch` is called. The cache is cleared whenever
| `Model.compile` is called. You can skip the cache and generate again the
| function with `force=True`.
|
| Args:
| force: Whether to regenerate the train function and skip the cached
| function if available.
|
| Returns:
| Function. The function created by this method should accept a
| `tf.data.Iterator`, and return a `dict` containing values that will
| be passed to `tf.keras.Callbacks.on_train_batch_end`, such as
| `{'loss': 0.2, 'accuracy': 0.7}`.
|
| predict(self, x, batch_size=None, verbose='auto', steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False)
| Generates output predictions for the input samples.
|
| Computation is done in batches. This method is designed for batch
| processing of large numbers of inputs. It is not intended for use inside
| of loops that iterate over your data and process small numbers of inputs
| at a time.
|
| For small numbers of inputs that fit in one batch,
| directly use `__call__()` for faster execution, e.g.,
| `model(x)`, or `model(x, training=False)` if you have layers such as
| `tf.keras.layers.BatchNormalization` that behave differently during
| inference. You may pair the individual model call with a `tf.function`
| for additional performance inside your inner loop.
| If you need access to numpy array values instead of tensors after your
| model call, you can use `tensor.numpy()` to get the numpy array value of
| an eager tensor.
|
| Also note that test loss is not affected by regularization layers
| such as noise and dropout.
|
| Note: See [this FAQ entry](
| https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call)
| for more details about the difference between `Model` methods
| `predict()` and `__call__()`.
|
| Args:
| x: Input samples. It could be:
| - A Numpy array (or array-like), or a list of arrays
| (in case the model has multiple inputs).
| - A TensorFlow tensor, or a list of tensors
| (in case the model has multiple inputs).
| - A `tf.data` dataset.
| - A generator or `keras.utils.Sequence` instance.
| A more detailed description of unpacking behavior for iterator
| types (Dataset, generator, Sequence) is given in the `Unpacking
| behavior for iterator-like inputs` section of `Model.fit`.
| batch_size: Integer or `None`.
| Number of samples per batch.
| If unspecified, `batch_size` will default to 32.
| Do not specify `batch_size` if your data is a dataset, a
| generator, or a `keras.utils.Sequence` instance
| (since these generate batches themselves).
| verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
| 0 = silent, 1 = progress bar, 2 = single line.
| `"auto"` becomes 1 in most cases, and 2 when used with
| `ParameterServerStrategy`. Note that the progress bar is not
| particularly useful when logged to a file, so `verbose=2` is
| recommended when not running interactively (e.g. in a production
| environment). Defaults to 'auto'.
| steps: Total number of steps (batches of samples)
| before declaring the prediction round finished.
| Ignored with the default value of `None`. If x is a `tf.data`
| dataset and `steps` is None, `predict()` will
| run until the input dataset is exhausted.
| callbacks: List of `keras.callbacks.Callback` instances.
| List of callbacks to apply during prediction.
| See [callbacks](
| https://www.tensorflow.org/api_docs/python/tf/keras/callbacks).
| max_queue_size: Integer. Used for generator or
| `keras.utils.Sequence` input only. Maximum size for the
| generator queue. If unspecified, `max_queue_size` will default
| to 10.
| workers: Integer. Used for generator or `keras.utils.Sequence` input
| only. Maximum number of processes to spin up when using
| process-based threading. If unspecified, `workers` will default
| to 1.
| use_multiprocessing: Boolean. Used for generator or
| `keras.utils.Sequence` input only. If `True`, use process-based
| threading. If unspecified, `use_multiprocessing` will default to
| `False`. Note that because this implementation relies on
| multiprocessing, you should not pass non-picklable arguments to
| the generator as they can't be passed easily to child
| processes.
|
| See the discussion of `Unpacking behavior for iterator-like inputs` for
| `Model.fit`. Note that `Model.predict` uses the same interpretation rules
| as `Model.fit` and `Model.evaluate`, so inputs must be unambiguous for
| all three methods.
|
| Returns:
| Numpy array(s) of predictions.
|
| Raises:
| RuntimeError: If `model.predict` is wrapped in a `tf.function`.
| ValueError: In case of mismatch between the provided
| input data and the model's expectations,
| or in case a stateful model receives a number of samples
| that is not a multiple of the batch size.
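
The difference between `predict()` and a direct call can be sketched as follows. This is an illustrative example on a tiny plain Keras model (an assumption, chosen so it runs quickly), not a benchmark; the variable names `batched` and `direct` are introduced here.

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs, outputs)

x = np.random.random((4, 3)).astype("float32")

# Batched path: iterates over the input in batches and returns NumPy array(s).
batched = model.predict(x, batch_size=2, verbose=0)

# Direct call: returns an eager tensor; use .numpy() to get a NumPy array.
direct = model(x, training=False).numpy()

print(batched.shape, direct.shape)
```

For a handful of rows that fit in one batch, the direct call avoids the per-call setup that `predict()` performs; for large inputs, `predict()` handles the batching.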
|
| predict_generator(self, generator, steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)
| Generates predictions for the input samples from a data generator.
|
| DEPRECATED:
| `Model.predict` now supports generators, so there is no longer any
| need to use this endpoint.
|
| predict_on_batch(self, x)
| Returns predictions for a single batch of samples.
|
| Args:
| x: Input data. It could be:
| - A Numpy array (or array-like), or a list of arrays (in case the
| model has multiple inputs).
| - A TensorFlow tensor, or a list of tensors (in case the model has
| multiple inputs).
|
| Returns:
| Numpy array(s) of predictions.
|
| Raises:
| RuntimeError: If `model.predict_on_batch` is wrapped in a
| `tf.function`.
|
| predict_step(self, data)
| The logic for one inference step.
|
| This method can be overridden to support custom inference logic.
| This method is called by `Model.make_predict_function`.
|
| This method should contain the mathematical logic for one step of
| inference. This typically includes the forward pass.
|
| Configuration details for *how* this logic is run (e.g. `tf.function`
| and `tf.distribute.Strategy` settings) should be left to
| `Model.make_predict_function`, which can also be overridden.
|
| Args:
| data: A nested structure of `Tensor`s.
|
| Returns:
| The result of one inference step, typically the output of calling the
| `Model` on data.
|
| reset_metrics(self)
| Resets the state of all the metrics in the model.
|
| Examples:
|
| >>> inputs = tf.keras.layers.Input(shape=(3,))
| >>> outputs = tf.keras.layers.Dense(2)(inputs)
| >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
| >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
|
| >>> x = np.random.random((2, 3))
| >>> y = np.random.randint(0, 2, (2, 2))
| >>> _ = model.fit(x, y, verbose=0)
| >>> assert all(float(m.result()) for m in model.metrics)
|
| >>> model.reset_metrics()
| >>> assert all(float(m.result()) == 0 for m in model.metrics)
|
| reset_states(self)
|
| save_spec(self, dynamic_batch=True)
| Returns the `tf.TensorSpec` of call args as a tuple `(args, kwargs)`.
|
| This value is automatically defined after calling the model for the
| first time. Afterwards, you can use it when exporting the model for
| serving:
|
| ```python
| model = tf.keras.Model(...)
|
| @tf.function
| def serve(*args, **kwargs):
| outputs = model(*args, **kwargs)
| # Apply postprocessing steps, or add additional outputs.
| ...
| return outputs
|
| # arg_specs is `[tf.TensorSpec(...), ...]`. kwarg_specs, in this
| # example, is an empty dict since functional models do not use keyword
| # arguments.
| arg_specs, kwarg_specs = model.save_spec()
|
| model.save(path, signatures={
| 'serving_default': serve.get_concrete_function(*arg_specs,
| **kwarg_specs)
| })
| ```
|
| Args:
| dynamic_batch: Whether to set the batch sizes of all the returned
| `tf.TensorSpec` to `None`. (Note that when defining functional or
| Sequential models with `tf.keras.Input([...], batch_size=X)`, the
| batch size will always be preserved). Defaults to `True`.
| Returns:
| If the model inputs are defined, returns a tuple `(args, kwargs)`. All
| elements in `args` and `kwargs` are `tf.TensorSpec`.
| If the model inputs are not defined, returns `None`.
| The model inputs are automatically set when calling the model,
| `model.fit`, `model.evaluate` or `model.predict`.
|
| save_weights(self, filepath, overwrite=True, save_format=None, options=None)
| Saves all layer weights.
|
| Either saves in HDF5 or in TensorFlow format based on the `save_format`
| argument.
|
| When saving in HDF5 format, the weight file has:
| - `layer_names` (attribute), a list of strings
| (ordered names of model layers).
| - For every layer, a `group` named `layer.name`
| - For every such layer group, a group attribute `weight_names`,
| a list of strings
| (ordered names of weights tensor of the layer).
| - For every weight in the layer, a dataset
| storing the weight value, named after the weight tensor.
|
| When saving in TensorFlow format, all objects referenced by the network
| are saved in the same format as `tf.train.Checkpoint`, including any
| `Layer` instances or `Optimizer` instances assigned to object
| attributes. For networks constructed from inputs and outputs using
| `tf.keras.Model(inputs, outputs)`, `Layer` instances used by the network
| are tracked/saved automatically. For user-defined classes which inherit
| from `tf.keras.Model`, `Layer` instances must be assigned to object
| attributes, typically in the constructor. See the documentation of
| `tf.train.Checkpoint` and `tf.keras.Model` for details.
|
| While the formats are the same, do not mix `save_weights` and
| `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should
| be loaded using `Model.load_weights`. Checkpoints saved using
| `tf.train.Checkpoint.save` should be restored using the corresponding
| `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
| `save_weights` for training checkpoints.
|
| The TensorFlow format matches objects and variables by starting at a
| root object, `self` for `save_weights`, and greedily matching attribute
| names. For `Model.save` this is the `Model`, and for `Checkpoint.save`
| this is the `Checkpoint` even if the `Checkpoint` has a model attached.
| This means saving a `tf.keras.Model` using `save_weights` and loading
| into a `tf.train.Checkpoint` with a `Model` attached (or vice versa)
| will not match the `Model`'s variables. See the
| [guide to training checkpoints](
| https://www.tensorflow.org/guide/checkpoint) for details on
| the TensorFlow format.
|
| Args:
| filepath: String or PathLike, path to the file to save the weights
| to. When saving in TensorFlow format, this is the prefix used
| for checkpoint files (multiple files are generated). Note that
| the '.h5' suffix causes weights to be saved in HDF5 format.
| overwrite: Whether to silently overwrite any existing file at the
| target location, or provide the user with a manual prompt.
| save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
| '.keras' will default to HDF5 if `save_format` is `None`.
| Otherwise, `None` becomes 'tf'. Defaults to `None`.
| options: Optional `tf.train.CheckpointOptions` object that specifies
| options for saving weights.
|
| Raises:
| ImportError: If `h5py` is not available when attempting to save in
| HDF5 format.
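
The `save_format` rule described above can be restated as a few lines of plain Python. The helper below, `resolve_save_format`, is hypothetical and is not part of Keras; it only mirrors the documented defaulting behaviour and does not model any consistency checks Keras may perform between an explicit `save_format` and the file suffix.

```python
from typing import Optional


def resolve_save_format(filepath: str, save_format: Optional[str] = None) -> str:
    """Hypothetical helper mirroring the documented `save_format` rule."""
    if save_format is not None:
        if save_format not in ("tf", "h5"):
            raise ValueError(f"Unknown save_format: {save_format!r}")
        # An explicit format is used as given.
        return save_format
    # With save_format=None, '.h5' and '.keras' suffixes select HDF5.
    if filepath.endswith((".h5", ".keras")):
        return "h5"
    # Otherwise None becomes the TensorFlow checkpoint format.
    return "tf"


print(resolve_save_format("checkpoints/weights"))  # TensorFlow format
print(resolve_save_format("weights.h5"))           # HDF5 format
```

In practice this means the file suffix alone is usually enough to choose the format, and `save_format` only needs to be passed to override it.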
|
| test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True, return_dict=False)
| Test the model on a single batch of samples.
|
| Args:
| x: Input data. It could be:
| - A Numpy array (or array-like), or a list of arrays (in case the
| model has multiple inputs).
| - A TensorFlow tensor, or a list of tensors (in case the model has
| multiple inputs).
| - A dict mapping input names to the corresponding array/tensors,
| if the model has named inputs.
| y: Target data. Like the input data `x`, it could be either Numpy
| array(s) or TensorFlow tensor(s). It should be consistent with `x`
| (you cannot have Numpy inputs and tensor targets, or vice versa).
| sample_weight: Optional array of the same length as x, containing
| weights to apply to the model's loss for each sample. In the case
| of temporal data, you can pass a 2D array with shape (samples,
| sequence_length), to apply a different weight to every timestep of
| every sample.
| reset_metrics: If `True`, the metrics returned will be only for this
| batch. If `False`, the metrics will be statefully accumulated
| across batches.
| return_dict: If `True`, loss and metric results are returned as a
| dict, with each key being the name of the metric. If `False`, they
| are returned as a list.
|
| Returns:
| Scalar test loss (if the model has a single output and no metrics)
| or list of scalars (if the model has multiple outputs
| and/or metrics). The attribute `model.metrics_names` will give you
| the display labels for the scalar outputs.
|
| Raises:
| RuntimeError: If `model.test_on_batch` is wrapped in a
| `tf.function`.
|
| test_step(self, data)
| The logic for one evaluation step.
|
| This method can be overridden to support custom evaluation logic.
| This method is called by `Model.make_test_function`.
|
| This function should contain the mathematical logic for one step of
| evaluation.
| This typically includes the forward pass, loss calculation, and metrics
| updates.
|
| Configuration details for *how* this logic is run (e.g. `tf.function`
| and `tf.distribute.Strategy` settings) should be left to
| `Model.make_test_function`, which can also be overridden.
|
| Args:
| data: A nested structure of `Tensor`s.
|
| Returns:
| A `dict` containing values that will be passed to
| `tf.keras.callbacks.CallbackList.on_test_batch_end`. Typically, the
| values of the `Model`'s metrics are returned.
|
| to_json(self, **kwargs)
| Returns a JSON string containing the network configuration.
|
| To load a network from a JSON save file, use
| `keras.models.model_from_json(json_string, custom_objects={})`.
|
| Args:
| **kwargs: Additional keyword arguments to be passed to
| `json.dumps()`.
|
| Returns:
| A JSON string.
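
A short round trip through `to_json()` might look like the sketch below. It assumes a small plain Keras functional model; the JSON string carries the architecture only, so the weights are copied separately with `set_weights`. The names `json_string`, `config`, and `rebuilt` are introduced here for illustration.

```python
import json

import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2, name="out")(inputs)
model = tf.keras.Model(inputs, outputs)

# Serialize the architecture (not the weights) to a JSON string.
json_string = model.to_json()
config = json.loads(json_string)

# Rebuild an untrained copy from the JSON, then copy the weights over.
rebuilt = tf.keras.models.model_from_json(json_string)
rebuilt.set_weights(model.get_weights())

print(sorted(config.keys()))
```

Because only the configuration is stored, the rebuilt model starts with freshly initialized weights until `set_weights` or `load_weights` is called.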
|
| to_yaml(self, **kwargs)
| Returns a yaml string containing the network configuration.
|
| Note: Since TF 2.6, this method is no longer supported and will raise a
| RuntimeError.
|
| To load a network from a yaml save file, use
| `keras.models.model_from_yaml(yaml_string, custom_objects={})`.
|
| `custom_objects` should be a dictionary mapping
| the names of custom losses / layers / etc to the corresponding
| functions / classes.
|
| Args:
| **kwargs: Additional keyword arguments
| to be passed to `yaml.dump()`.
|
| Returns:
| A YAML string.
|
| Raises:
| RuntimeError: Raised to announce that the method poses a security risk.
|
| ----------------------------------------------------------------------
| Class methods inherited from keras.src.engine.training.Model:
|
| from_config(config, custom_objects=None) from builtins.type
| Creates a layer from its config.
|
| This method is the reverse of `get_config`,
| capable of instantiating the same layer from the config
| dictionary. It does not handle layer connectivity
| (handled by Network), nor weights (handled by `set_weights`).
|
| Args:
| config: A Python dictionary, typically the
| output of get_config.
|
| Returns:
| A layer instance.
|
| ----------------------------------------------------------------------
| Static methods inherited from keras.src.engine.training.Model:
|
| __new__(cls, *args, **kwargs)
| Create and return a new object. See help(type) for accurate signature.
|
| ----------------------------------------------------------------------
| Readonly properties inherited from keras.src.engine.training.Model:
|
| distribute_strategy
| The `tf.distribute.Strategy` this model was created under.
|
| metrics
| Return metrics added using `compile()` or `add_metric()`.
|
| Note: Metrics passed to `compile()` are available only after a
| `keras.Model` has been trained/evaluated on actual data.
|
| Examples:
|
| >>> inputs = tf.keras.layers.Input(shape=(3,))
| >>> outputs = tf.keras.layers.Dense(2)(inputs)
| >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
| >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
| >>> [m.name for m in model.metrics]
| []
|
| >>> x = np.random.random((2, 3))
| >>> y = np.random.randint(0, 2, (2, 2))
| >>> model.fit(x, y)
| >>> [m.name for m in model.metrics]
| ['loss', 'mae']
|
| >>> inputs = tf.keras.layers.Input(shape=(3,))
| >>> d = tf.keras.layers.Dense(2, name='out')
| >>> output_1 = d(inputs)
| >>> output_2 = d(inputs)
| >>> model = tf.keras.models.Model(
| ... inputs=inputs, outputs=[output_1, output_2])
| >>> model.add_metric(
| ... tf.reduce_sum(output_2), name='mean', aggregation='mean')
| >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
| >>> model.fit(x, (y, y))
| >>> [m.name for m in model.metrics]
| ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',
| 'out_1_acc', 'mean']
|
| metrics_names
| Returns the model's display labels for all outputs.
|
| Note: `metrics_names` are available only after a `keras.Model` has been
| trained/evaluated on actual data.
|
| Examples:
|
| >>> inputs = tf.keras.layers.Input(shape=(3,))
| >>> outputs = tf.keras.layers.Dense(2)(inputs)
| >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
| >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
| >>> model.metrics_names
| []
|
| >>> x = np.random.random((2, 3))
| >>> y = np.random.randint(0, 2, (2, 2))
| >>> model.fit(x, y)
| >>> model.metrics_names
| ['loss', 'mae']
|
| >>> inputs = tf.keras.layers.Input(shape=(3,))
| >>> d = tf.keras.layers.Dense(2, name='out')
| >>> output_1 = d(inputs)
| >>> output_2 = d(inputs)
| >>> model = tf.keras.models.Model(
| ... inputs=inputs, outputs=[output_1, output_2])
| >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
| >>> model.fit(x, (y, y))
| >>> model.metrics_names
| ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',
| 'out_1_acc']
|
| non_trainable_weights
| List of all non-trainable weights tracked by this layer.
|
| Non-trainable weights are *not* updated during training. They are
| expected to be updated manually in `call()`.
|
| Returns:
| A list of non-trainable variables.
|
| state_updates
| Deprecated, do NOT use!
|
| Returns the `updates` from all layers that are stateful.
|
| This is useful for separating training updates and
| state updates, e.g. when we need to update a layer's internal state
| during prediction.
|
| Returns:
| A list of update ops.
|
| trainable_weights
| List of all trainable weights tracked by this layer.
|
| Trainable weights are updated via gradient descent during training.
|
| Returns:
| A list of trainable variables.
|
| weights
| Returns the list of all layer variables/weights.
|
| Note: This will not track the weights of nested `tf.Modules` that are
| not themselves Keras layers.
|
| Returns:
| A list of variables.
|
| ----------------------------------------------------------------------
| Data descriptors inherited from keras.src.engine.training.Model:
|
| distribute_reduction_method
| The method employed to reduce per-replica values during training.
|
| Unless specified, the value "auto" will be assumed, indicating that
| the reduction strategy should be chosen based on the current
| running environment.
| See `reduce_per_replica` function for more details.
|
| jit_compile
| Specify whether to compile the model with XLA.
|
| [XLA](https://www.tensorflow.org/xla) is an optimizing compiler
| for machine learning. `jit_compile` is not enabled by default.
| Note that `jit_compile=True` may not necessarily work for all models.
|
| For more information on supported operations please refer to the
| [XLA documentation](https://www.tensorflow.org/xla). Also refer to
| [known XLA issues](https://www.tensorflow.org/xla/known_issues)
| for more details.
|
| layers
|
| run_eagerly
| Settable attribute indicating whether the model should run eagerly.
|
| Running eagerly means that your model will be run step by step,
| like Python code. Your model might run slower, but it should become
| easier for you to debug it by stepping into individual layer calls.
|
| By default, we will attempt to compile your model to a static graph to
| deliver the best execution performance.
|
| Returns:
| Boolean, whether the model should run eagerly.
|
| ----------------------------------------------------------------------
| Methods inherited from keras.src.engine.base_layer.Layer:
|
| __delattr__(self, name)
| Implement delattr(self, name).
|
| __getstate__(self)
|
| __setstate__(self, state)
|
| add_loss(self, losses, **kwargs)
| Add loss tensor(s), potentially dependent on layer inputs.
|
| Some losses (for instance, activity regularization losses) may be
| dependent on the inputs passed when calling a layer. Hence, when reusing
| the same layer on different inputs `a` and `b`, some entries in
| `layer.losses` may be dependent on `a` and some on `b`. This method
| automatically keeps track of dependencies.
|
| This method can be used inside a subclassed layer or model's `call`
| function, in which case `losses` should be a Tensor or list of Tensors.
|
| Example:
|
| ```python
| class MyLayer(tf.keras.layers.Layer):
| def call(self, inputs):
| self.add_loss(tf.abs(tf.reduce_mean(inputs)))
| return inputs
| ```
|
| The same code works in distributed training: the input to `add_loss()`
| is treated like a regularization loss and averaged across replicas
| by the training loop (both built-in `Model.fit()` and compliant custom
| training loops).
|
| The `add_loss` method can also be called directly on a Functional Model
| during construction. In this case, any loss Tensors passed to this Model
| must be symbolic and be able to be traced back to the model's `Input`s.
| These losses become part of the model's topology and are tracked in
| `get_config`.
|
| Example:
|
| ```python
| inputs = tf.keras.Input(shape=(10,))
| x = tf.keras.layers.Dense(10)(inputs)
| outputs = tf.keras.layers.Dense(1)(x)
| model = tf.keras.Model(inputs, outputs)
| # Activity regularization.
| model.add_loss(tf.abs(tf.reduce_mean(x)))
| ```
|
| If this is not the case for your loss (if, for example, your loss
| references a `Variable` of one of the model's layers), you can wrap your
| loss in a zero-argument lambda. These losses are not tracked as part of
| the model's topology since they can't be serialized.
|
| Example:
|
| ```python
| inputs = tf.keras.Input(shape=(10,))
| d = tf.keras.layers.Dense(10)
| x = d(inputs)
| outputs = tf.keras.layers.Dense(1)(x)
| model = tf.keras.Model(inputs, outputs)
| # Weight regularization.
| model.add_loss(lambda: tf.reduce_mean(d.kernel))
| ```
|
| Args:
| losses: Loss tensor, or list/tuple of tensors. Rather than tensors,
| losses may also be zero-argument callables which create a loss
| tensor.
| **kwargs: Used for backwards compatibility only.
|
| add_metric(self, value, name=None, **kwargs)
| Adds metric tensor to the layer.
|
| This method can be used inside the `call()` method of a subclassed layer
| or model.
|
| ```python
| class MyMetricLayer(tf.keras.layers.Layer):
| def __init__(self):
| super(MyMetricLayer, self).__init__(name='my_metric_layer')
| self.mean = tf.keras.metrics.Mean(name='metric_1')
|
| def call(self, inputs):
| self.add_metric(self.mean(inputs))
| self.add_metric(tf.reduce_sum(inputs), name='metric_2')
| return inputs
| ```
|
| This method can also be called directly on a Functional Model during
| construction. In this case, any tensor passed to this Model must
| be symbolic and be able to be traced back to the model's `Input`s. These
| metrics become part of the model's topology and are tracked when you
| save the model via `save()`.
|
| ```python
| inputs = tf.keras.Input(shape=(10,))
| x = tf.keras.layers.Dense(10)(inputs)
| outputs = tf.keras.layers.Dense(1)(x)
| model = tf.keras.Model(inputs, outputs)
| model.add_metric(tf.reduce_sum(x), name='metric_1')
| ```
|
| Note: Calling `add_metric()` with the result of a metric object on a
| Functional Model, as shown in the example below, is not supported. This
| is because we cannot trace the metric result tensor back to the model's
| inputs.
|
| ```python
| inputs = tf.keras.Input(shape=(10,))
| x = tf.keras.layers.Dense(10)(inputs)
| outputs = tf.keras.layers.Dense(1)(x)
| model = tf.keras.Model(inputs, outputs)
| model.add_metric(tf.keras.metrics.Mean()(x), name='metric_1')
| ```
|
| Args:
| value: Metric tensor.
| name: String metric name.
| **kwargs: Additional keyword arguments for backward compatibility.
| Accepted values:
| `aggregation` - When the `value` tensor provided is not the result
| of calling a `keras.Metric` instance, it will be aggregated by
| default using a `keras.metrics.Mean`.
|
| add_update(self, updates)
| Add update op(s), potentially dependent on layer inputs.
|
| Weight updates (for instance, the updates of the moving mean and
| variance in a BatchNormalization layer) may be dependent on the inputs
| passed when calling a layer. Hence, when reusing the same layer on
| different inputs `a` and `b`, some entries in `layer.updates` may be
| dependent on `a` and some on `b`. This method automatically keeps track
| of dependencies.
|
| This call is ignored when eager execution is enabled (in that case,
| variable updates are run on the fly and thus do not need to be tracked
| for later execution).
|
| Args:
| updates: Update op, or list/tuple of update ops, or zero-arg callable
| that returns an update op. A zero-arg callable should be passed in
| order to disable running the updates by setting `trainable=False`
| on this Layer, when executing in Eager mode.
|
| add_variable(self, *args, **kwargs)
| Deprecated, do NOT use! Alias for `add_weight`.
|
| add_weight(self, name=None, shape=None, dtype=None, initializer=None, regularizer=None, trainable=None, constraint=None, use_resource=None, synchronization=<VariableSynchronization.AUTO: 0>, aggregation=<VariableAggregationV2.NONE: 0>, **kwargs)
| Adds a new variable to the layer.
|
| Args:
| name: Variable name.
| shape: Variable shape. Defaults to scalar if unspecified.
| dtype: The type of the variable. Defaults to `self.dtype`.
| initializer: Initializer instance (callable).
| regularizer: Regularizer instance (callable).
| trainable: Boolean, whether the variable should be part of the layer's
| "trainable_variables" (e.g. kernels, biases)
| or "non_trainable_variables" (e.g. BatchNorm mean and variance).
| Note that `trainable` cannot be `True` if `synchronization`
| is set to `ON_READ`.
| constraint: Constraint instance (callable).
| use_resource: Whether to use a `ResourceVariable` or not.
| See [this guide](
| https://www.tensorflow.org/guide/migrate/tf1_vs_tf2#resourcevariables_instead_of_referencevariables)
| for more information.
| synchronization: Indicates when a distributed variable will be
| aggregated. Accepted values are constants defined in the class
| `tf.VariableSynchronization`. By default the synchronization is set
| to `AUTO` and the current `DistributionStrategy` chooses when to
| synchronize. If `synchronization` is set to `ON_READ`, `trainable`
| must not be set to `True`.
| aggregation: Indicates how a distributed variable will be aggregated.
| Accepted values are constants defined in the class
| `tf.VariableAggregation`.
| **kwargs: Additional keyword arguments. Accepted values are `getter`,
| `collections`, `experimental_autocast` and `caching_device`.
|
| Returns:
| The variable created.
|
| Raises:
| ValueError: When giving unsupported dtype and no initializer or when
| trainable has been set to True with synchronization set as
| `ON_READ`.
|
| build_from_config(self, config)
| Builds the layer's states with the supplied config dict.
|
| By default, this method calls the `build(config["input_shape"])` method,
| which creates weights based on the layer's input shape in the supplied
| config. If your config contains other information needed to load the
| layer's state, you should override this method.
|
| Args:
| config: Dict containing the input shape associated with this layer.
|
| compute_mask(self, inputs, mask=None)
| Computes an output mask tensor.
|
| Args:
| inputs: Tensor or list of tensors.
| mask: Tensor or list of tensors.
|
| Returns:
| None or a tensor (or list of tensors,
| one per output tensor of the layer).
|
| compute_output_shape(self, input_shape)
| Computes the output shape of the layer.
|
| This method will cause the layer's state to be built, if that has not
| happened before. This requires that the layer will later be used with
| inputs that match the input shape provided here.
|
| Args:
| input_shape: Shape tuple (tuple of integers) or `tf.TensorShape`,
| or structure of shape tuples / `tf.TensorShape` instances
| (one per output tensor of the layer).
| Shape tuples can include None for free dimensions,
| instead of an integer.
|
| Returns:
| A `tf.TensorShape` instance
| or structure of `tf.TensorShape` instances.
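
As a small illustration of shape inference with a free batch dimension, the sketch below asks a `Dense` layer what it would produce for a `(None, 8)` input. The layer and sizes are arbitrary choices for the example; the result is converted with `tuple(...)` because the concrete return type (a `tf.TensorShape` or a plain tuple) can vary between Keras versions.

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(4)

# None stands for a free (unknown) batch dimension.
output_shape = layer.compute_output_shape((None, 8))

# Normalize to a plain tuple for display and comparison.
output_shape_tuple = tuple(output_shape)
print(output_shape_tuple)
```

Only the last dimension changes for `Dense`: the free batch dimension is passed through unchanged.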
|
| compute_output_signature(self, input_signature)
| Compute the output tensor signature of the layer based on the inputs.
|
| Unlike a TensorShape object, a TensorSpec object contains both shape
| and dtype information for a tensor. This method allows layers to provide
| output dtype information if it is different from the input dtype.
| For any layer that doesn't implement this function,
| the framework will fall back to use `compute_output_shape`, and will
| assume that the output dtype matches the input dtype.
|
| Args:
| input_signature: Single TensorSpec or nested structure of TensorSpec
| objects, describing a candidate input for the layer.
|
| Returns:
| Single TensorSpec or nested structure of TensorSpec objects,
| describing how the layer would transform the provided input.
|
| Raises:
| TypeError: If input_signature contains a non-TensorSpec object.
|
| count_params(self)
| Count the total number of scalars composing the weights.
|
| Returns:
| An integer count.
|
| Raises:
| ValueError: if the layer isn't yet built
| (in which case its weights aren't yet defined).
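
The count can be checked by hand for a small model. The sketch below reuses the `d1`/`d2` layer sizes from the `get_weight_paths` example on a plain Keras model (an assumption for illustration) and compares `count_params()` with the kernel-plus-bias arithmetic; `expected` is a name introduced here.

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(10,))
x = tf.keras.layers.Dense(20, name="d1")(inputs)
outputs = tf.keras.layers.Dense(30, name="d2")(x)
model = tf.keras.Model(inputs, outputs)

total = model.count_params()

# Each Dense layer contributes (inputs * units) kernel values plus units biases.
expected = (10 * 20 + 20) + (20 * 30 + 30)
print(total, expected)
```

The functional model is already built at construction time, so no `ValueError` is raised here; calling `count_params()` on a layer that has never seen an input shape is the case the `Raises` entry refers to.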
|
| finalize_state(self)
| Finalizes the layer's state after updating layer weights.
|
| This function can be overridden in a layer subclass and will be called
| after a layer's weights are updated. Use it to finalize any
| additional layer state after a weight update.
|
| This function will be called after weights of a layer have been restored
| from a loaded model.
|
| get_build_config(self)
| Returns a dictionary with the layer's input shape.
|
| This method returns a config dict that can be used by
| `build_from_config(config)` to create all states (e.g. Variables and
| Lookup tables) needed by the layer.
|
| By default, the config only contains the input shape that the layer
| was built with. If you're writing a custom layer that creates state in
| an unusual way, you should override this method to make sure this state
| is already created when Keras attempts to load its value upon model
| loading.
|
| Returns:
| A dict containing the input shape associated with the layer.
|
| get_input_at(self, node_index)
| Retrieves the input tensor(s) of a layer at a given node.
|
| Args:
| node_index: Integer, index of the node
| from which to retrieve the attribute.
| E.g. `node_index=0` will correspond to the
| first input node of the layer.
|
| Returns:
| A tensor (or list of tensors if the layer has multiple inputs).
|
| Raises:
| RuntimeError: If called in Eager mode.
|
| get_input_mask_at(self, node_index)
| Retrieves the input mask tensor(s) of a layer at a given node.
|
| Args:
| node_index: Integer, index of the node
| from which to retrieve the attribute.
| E.g. `node_index=0` will correspond to the
| first time the layer was called.
|
| Returns:
| A mask tensor
| (or list of tensors if the layer has multiple inputs).
|
| get_input_shape_at(self, node_index)
| Retrieves the input shape(s) of a layer at a given node.
|
| Args:
| node_index: Integer, index of the node
| from which to retrieve the attribute.
| E.g. `node_index=0` will correspond to the
| first time the layer was called.
|
| Returns:
| A shape tuple
| (or list of shape tuples if the layer has multiple inputs).
|
| Raises:
| RuntimeError: If called in Eager mode.
|
| get_output_at(self, node_index)
| Retrieves the output tensor(s) of a layer at a given node.
|
| Args:
| node_index: Integer, index of the node
| from which to retrieve the attribute.
| E.g. `node_index=0` will correspond to the
| first output node of the layer.
|
| Returns:
| A tensor (or list of tensors if the layer has multiple outputs).
|
| Raises:
| RuntimeError: If called in Eager mode.
|
| get_output_mask_at(self, node_index)
| Retrieves the output mask tensor(s) of a layer at a given node.
|
| Args:
| node_index: Integer, index of the node
| from which to retrieve the attribute.
| E.g. `node_index=0` will correspond to the
| first time the layer was called.
|
| Returns:
| A mask tensor
| (or list of tensors if the layer has multiple outputs).
|
| get_output_shape_at(self, node_index)
| Retrieves the output shape(s) of a layer at a given node.
|
| Args:
| node_index: Integer, index of the node
| from which to retrieve the attribute.
| E.g. `node_index=0` will correspond to the
| first time the layer was called.
|
| Returns:
| A shape tuple
| (or list of shape tuples if the layer has multiple outputs).
|
| Raises:
| RuntimeError: If called in Eager mode.
|
| load_own_variables(self, store)
| Loads the state of the layer.
|
| You can override this method to take full control of how the state of
| the layer is loaded upon calling `keras.models.load_model()`.
|
| Args:
| store: Dict from which the state of the model will be loaded.
|
| save_own_variables(self, store)
| Saves the state of the layer.
|
| You can override this method to take full control of how the state of
| the layer is saved upon calling `model.save()`.
|
| Args:
| store: Dict where the state of the model will be saved.
|
| set_weights(self, weights)
| Sets the weights of the layer, from NumPy arrays.
|
| The weights of a layer represent the state of the layer. This function
| sets the weight values from numpy arrays. The weight values should be
| passed in the order they are created by the layer. Note that the layer's
| weights must be instantiated before calling this function, by calling
| the layer.
|
| For example, a `Dense` layer returns a list of two values: the kernel
| matrix and the bias vector. These can be used to set the weights of
| another `Dense` layer:
|
| >>> layer_a = tf.keras.layers.Dense(1,
| ... kernel_initializer=tf.constant_initializer(1.))
| >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]]))
| >>> layer_a.get_weights()
| [array([[1.],
| [1.],
| [1.]], dtype=float32), array([0.], dtype=float32)]
| >>> layer_b = tf.keras.layers.Dense(1,
| ... kernel_initializer=tf.constant_initializer(2.))
| >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]]))
| >>> layer_b.get_weights()
| [array([[2.],
| [2.],
| [2.]], dtype=float32), array([0.], dtype=float32)]
| >>> layer_b.set_weights(layer_a.get_weights())
| >>> layer_b.get_weights()
| [array([[1.],
| [1.],
| [1.]], dtype=float32), array([0.], dtype=float32)]
|
| Args:
| weights: a list of NumPy arrays. The number
| of arrays and their shape must match
| number of the dimensions of the weights
| of the layer (i.e. it should match the
| output of `get_weights`).
|
| Raises:
| ValueError: If the provided weights list does not match the
| layer's specifications.
|
| ----------------------------------------------------------------------
| Readonly properties inherited from keras.src.engine.base_layer.Layer:
|
| compute_dtype
| The dtype of the layer's computations.
|
| This is equivalent to `Layer.dtype_policy.compute_dtype`. Unless
| mixed precision is used, this is the same as `Layer.dtype`, the dtype of
| the weights.
|
| Layers automatically cast their inputs to the compute dtype, which
| causes computations and the output to be in the compute dtype as well.
| This is done by the base Layer class in `Layer.__call__`, so you do not
| have to insert these casts if implementing your own layer.
|
| Layers often perform certain internal computations in higher precision
| when `compute_dtype` is float16 or bfloat16 for numeric stability. The
| output will still typically be float16 or bfloat16 in such cases.
|
| Returns:
| The layer's compute dtype.
|
| dtype
| The dtype of the layer weights.
|
| This is equivalent to `Layer.dtype_policy.variable_dtype`. Unless
| mixed precision is used, this is the same as `Layer.compute_dtype`, the
| dtype of the layer's computations.
|
| dtype_policy
| The dtype policy associated with this layer.
|
| This is an instance of a `tf.keras.mixed_precision.Policy`.
|
| dynamic
| Whether the layer is dynamic (eager-only); set in the constructor.
|
| inbound_nodes
| Return Functional API nodes upstream of this layer.
|
| input
| Retrieves the input tensor(s) of a layer.
|
| Only applicable if the layer has exactly one input,
| i.e. if it is connected to one incoming layer.
|
| Returns:
| Input tensor or list of input tensors.
|
| Raises:
| RuntimeError: If called in Eager mode.
| AttributeError: If no inbound nodes are found.
|
| input_mask
| Retrieves the input mask tensor(s) of a layer.
|
| Only applicable if the layer has exactly one inbound node,
| i.e. if it is connected to one incoming layer.
|
| Returns:
| Input mask tensor (potentially None) or list of input
| mask tensors.
|
| Raises:
| AttributeError: if the layer is connected to
| more than one incoming layers.
|
| input_shape
| Retrieves the input shape(s) of a layer.
|
| Only applicable if the layer has exactly one input,
| i.e. if it is connected to one incoming layer, or if all inputs
| have the same shape.
|
| Returns:
| Input shape, as an integer shape tuple
| (or list of shape tuples, one tuple per input tensor).
|
| Raises:
| AttributeError: if the layer has no defined input_shape.
| RuntimeError: if called in Eager mode.
|
| losses
| List of losses added using the `add_loss()` API.
|
| Variable regularization tensors are created when this property is
| accessed, so it is eager safe: accessing `losses` under a
| `tf.GradientTape` will propagate gradients back to the corresponding
| variables.
|
| Examples:
|
| >>> class MyLayer(tf.keras.layers.Layer):
| ... def call(self, inputs):
| ... self.add_loss(tf.abs(tf.reduce_mean(inputs)))
| ... return inputs
| >>> l = MyLayer()
| >>> l(np.ones((10, 1)))
| >>> l.losses
| [1.0]
|
| >>> inputs = tf.keras.Input(shape=(10,))
| >>> x = tf.keras.layers.Dense(10)(inputs)
| >>> outputs = tf.keras.layers.Dense(1)(x)
| >>> model = tf.keras.Model(inputs, outputs)
| >>> # Activity regularization.
| >>> len(model.losses)
| 0
| >>> model.add_loss(tf.abs(tf.reduce_mean(x)))
| >>> len(model.losses)
| 1
|
| >>> inputs = tf.keras.Input(shape=(10,))
| >>> d = tf.keras.layers.Dense(10, kernel_initializer='ones')
| >>> x = d(inputs)
| >>> outputs = tf.keras.layers.Dense(1)(x)
| >>> model = tf.keras.Model(inputs, outputs)
| >>> # Weight regularization.
| >>> model.add_loss(lambda: tf.reduce_mean(d.kernel))
| >>> model.losses
| [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>]
|
| Returns:
| A list of tensors.
|
| name
| Name of the layer (string), set in the constructor.
|
| non_trainable_variables
| Sequence of non-trainable variables owned by this module and its submodules.
|
| Note: this method uses reflection to find variables on the current instance
| and submodules. For performance reasons you may wish to cache the result
| of calling this method if you don't expect the return value to change.
|
| Returns:
| A sequence of variables for the current module (sorted by attribute
| name) followed by variables from all submodules recursively (breadth
| first).
|
| outbound_nodes
| Return Functional API nodes downstream of this layer.
|
| output
| Retrieves the output tensor(s) of a layer.
|
| Only applicable if the layer has exactly one output,
| i.e. if it is connected to one incoming layer.
|
| Returns:
| Output tensor or list of output tensors.
|
| Raises:
| AttributeError: if the layer is connected to more than one incoming
| layers.
| RuntimeError: if called in Eager mode.
|
| output_mask
| Retrieves the output mask tensor(s) of a layer.
|
| Only applicable if the layer has exactly one inbound node,
| i.e. if it is connected to one incoming layer.
|
| Returns:
| Output mask tensor (potentially None) or list of output
| mask tensors.
|
| Raises:
| AttributeError: if the layer is connected to
| more than one incoming layers.
|
| output_shape
| Retrieves the output shape(s) of a layer.
|
| Only applicable if the layer has one output,
| or if all outputs have the same shape.
|
| Returns:
| Output shape, as an integer shape tuple
| (or list of shape tuples, one tuple per output tensor).
|
| Raises:
| AttributeError: if the layer has no defined output shape.
| RuntimeError: if called in Eager mode.
|
| trainable_variables
| Sequence of trainable variables owned by this module and its submodules.
|
| Note: this method uses reflection to find variables on the current instance
| and submodules. For performance reasons you may wish to cache the result
| of calling this method if you don't expect the return value to change.
|
| Returns:
| A sequence of variables for the current module (sorted by attribute
| name) followed by variables from all submodules recursively (breadth
| first).
|
| updates
|
| variable_dtype
| Alias of `Layer.dtype`, the dtype of the weights.
|
| variables
| Returns the list of all layer variables/weights.
|
| Alias of `self.weights`.
|
| Note: This will not track the weights of nested `tf.Modules` that are
| not themselves Keras layers.
|
| Returns:
| A list of variables.
|
| ----------------------------------------------------------------------
| Data descriptors inherited from keras.src.engine.base_layer.Layer:
|
| activity_regularizer
| Optional regularizer function for the output of this layer.
|
| input_spec
| `InputSpec` instance(s) describing the input format for this layer.
|
| When you create a layer subclass, you can set `self.input_spec` to
| enable the layer to run input compatibility checks when it is called.
| Consider a `Conv2D` layer: it can only be called on a single input
| tensor of rank 4. As such, you can set, in `__init__()`:
|
| ```python
| self.input_spec = tf.keras.layers.InputSpec(ndim=4)
| ```
|
| Now, if you try to call the layer on an input that isn't rank 4
| (for instance, an input of shape `(2,)`, it will raise a
| nicely-formatted error:
|
| ```
| ValueError: Input 0 of layer conv2d is incompatible with the layer:
| expected ndim=4, found ndim=1. Full shape received: [2]
| ```
|
| Input checks that can be specified via `input_spec` include:
| - Structure (e.g. a single input, a list of 2 inputs, etc)
| - Shape
| - Rank (ndim)
| - Dtype
|
| For more information, see `tf.keras.layers.InputSpec`.
|
| Returns:
| A `tf.keras.layers.InputSpec` instance, or nested structure thereof.
|
| stateful
|
| supports_masking
| Whether this layer supports computing a mask using `compute_mask`.
|
| trainable
|
| ----------------------------------------------------------------------
| Class methods inherited from tensorflow.python.module.module.Module:
|
| with_name_scope(method) from builtins.type
| Decorator to automatically enter the module name scope.
|
| >>> class MyModule(tf.Module):
| ... @tf.Module.with_name_scope
| ... def __call__(self, x):
| ... if not hasattr(self, 'w'):
| ... self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
| ... return tf.matmul(x, self.w)
|
| Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose
| names included the module name:
|
| >>> mod = MyModule()
| >>> mod(tf.ones([1, 2]))
| <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
| >>> mod.w
| <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
| numpy=..., dtype=float32)>
|
| Args:
| method: The method to wrap.
|
| Returns:
| The original method wrapped such that it enters the module's name scope.
|
| ----------------------------------------------------------------------
| Readonly properties inherited from tensorflow.python.module.module.Module:
|
| name_scope
| Returns a `tf.name_scope` instance for this class.
|
| submodules
| Sequence of all sub-modules.
|
| Submodules are modules which are properties of this module, or found as
| properties of modules which are properties of this module (and so on).
|
| >>> a = tf.Module()
| >>> b = tf.Module()
| >>> c = tf.Module()
| >>> a.b = b
| >>> b.c = c
| >>> list(a.submodules) == [b, c]
| True
| >>> list(b.submodules) == [c]
| True
| >>> list(c.submodules) == []
| True
|
| Returns:
| A sequence of all submodules.
|
| ----------------------------------------------------------------------
| Data descriptors inherited from tensorflow.python.trackable.base.Trackable:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
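12. Using a subset of features
The previous examples did not specify the features, so every column except the label was used as an input feature. The following example shows how to specify the input features.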
# Create the feature objects feature_1 and feature_2
feature_1 = tfdf.keras.FeatureUsage(name="bill_length_mm")
feature_2 = tfdf.keras.FeatureUsage(name="island")
# Collect the feature objects in the list all_features
all_features = [feature_1, feature_2]
# Note: this model is trained on only two features, so it is not expected to be as good as the model trained on all features.
# Create the gradient boosted trees model model_2
model_2 = tfdf.keras.GradientBoostedTreesModel(
    features=all_features, exclude_non_specified_features=True)
# Compile the model with accuracy as the evaluation metric
model_2.compile(metrics=["accuracy"])
# Train on train_ds, using test_ds as the validation dataset
model_2.fit(train_ds, validation_data=test_ds)
# Print the evaluation on the test dataset, returned as a dictionary
print(model_2.evaluate(test_ds, return_dict=True))
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpoow4zpd8 as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.144474. Found 239 examples.
Reading validation dataset...
[WARNING 23-08-16 11:05:24.8447 UTC gradient_boosted_trees.cc:1818] "goss_alpha" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:24.8447 UTC gradient_boosted_trees.cc:1829] "goss_beta" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:24.8447 UTC gradient_boosted_trees.cc:1843] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
Num validation examples: tf.Tensor(105, shape=(), dtype=int32)
Validation dataset read in 0:00:00.205776. Found 105 examples.
Training model...
Model trained in 0:00:00.610888
Compiling model...
Model compiled.
[INFO 23-08-16 11:05:25.7975 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpoow4zpd8/model/ with prefix 730fc7c477154271
[INFO 23-08-16 11:05:25.8146 UTC decision_forest.cc:660] Model loaded with 168 root(s), 5352 node(s), and 2 input feature(s).
[INFO 23-08-16 11:05:25.8146 UTC kernel.cc:1075] Use fast generic engine
1/1 [==============================] - 0s 87ms/step - loss: 0.0000e+00 - accuracy: 0.9810
{'loss': 0.0, 'accuracy': 0.9809523820877075}
**Note:** As expected, the accuracy is lower than before.
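To make the effect of `exclude_non_specified_features` concrete, here is a minimal stand-alone sketch of the column selection rule described above. The helper `select_input_features` is hypothetical and is not part of TF-DF; the ordering of the returned columns is an assumption made for illustration only.

```python
def select_input_features(columns, label, specified, exclude_non_specified):
    """Return the columns a model would consume (illustrative helper, not a TF-DF API)."""
    # The label is never used as an input feature.
    candidates = [c for c in columns if c != label]
    if not specified:
        # No explicit features: every non-label column is used.
        return candidates
    unknown = [name for name in specified if name not in candidates]
    if unknown:
        raise ValueError(f"Unknown feature(s): {unknown}")
    if exclude_non_specified:
        # Only the explicitly listed features are kept.
        return list(specified)
    # Otherwise the listed features are kept and the remaining columns are added.
    return list(specified) + [c for c in candidates if c not in specified]


# Column names of the penguins dataset used in this tutorial.
columns = ["species", "island", "bill_length_mm", "bill_depth_mm",
           "flipper_length_mm", "body_mass_g", "sex", "year"]

used = select_input_features(
    columns, label="species",
    specified=["bill_length_mm", "island"],
    exclude_non_specified=True)
print(used)
```

With `exclude_non_specified=True` only the two listed columns remain, which matches the "2 input feature(s)" reported in the training log above.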
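TF-DF attaches a semantic to each feature. The semantic controls how the model uses the feature. The following semantics are currently supported:
- Numerical: generally for quantities or counts with a full ordering, for example the age of a person or the number of items in a bag. Can be a float or an integer. Missing values are represented with a float NaN or with an empty sparse tensor.
- Categorical: generally for a type/class taken from a finite set of possible values without ordering, for example the color RED in the set {RED, BLUE, GREEN}. Can be a string or an integer. Missing values are represented as the empty string "", the value -2, or an empty sparse tensor.
- Categorical-set: a set of categorical values. Well suited to representing tokenized text. Can be a string or an integer, stored in a sparse tensor or a ragged tensor (recommended). The order/index of each item does not matter.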
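The missing-value conventions in the list above can be summarised with a small sketch. The function `is_missing` is a hypothetical helper written for this illustration, not a TF-DF function, and it only covers dense scalar values (empty sparse tensors are not modelled).

```python
import math


def is_missing(value, semantic):
    """Apply the missing-value conventions listed above (illustrative helper only)."""
    if semantic == "NUMERICAL":
        # Numerical features: a float NaN marks a missing value.
        return isinstance(value, float) and math.isnan(value)
    if semantic == "CATEGORICAL":
        # Categorical features: the empty string or the integer -2 marks a missing value.
        return value == "" or value == -2
    raise ValueError(f"Unsupported semantic: {semantic}")


print(is_missing(float("nan"), "NUMERICAL"))
print(is_missing("", "CATEGORICAL"))
print(is_missing("RED", "CATEGORICAL"))
```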
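If the semantic is not specified, it is inferred from the representation type and shown in the training logs:
- int, float (dense or sparse) → Numerical semantic.
- str (dense or sparse) → Categorical semantic.
- int, str (ragged) → Categorical-set semantic.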
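The inference rules above can be written down as a tiny lookup. This is only a sketch of the three rules as stated in the list; `infer_semantic` is a hypothetical name, and the real library works on tensor dtypes rather than on Python types.

```python
def infer_semantic(value_type, ragged=False):
    """Mirror the three inference rules listed above (illustrative, not the TF-DF implementation)."""
    if ragged:
        if value_type in (int, str):
            return "CATEGORICAL_SET"
        raise ValueError(f"No rule for ragged values of type {value_type.__name__}")
    if value_type in (int, float):
        return "NUMERICAL"
    if value_type is str:
        return "CATEGORICAL"
    raise ValueError(f"No rule for values of type {value_type.__name__}")


print(infer_semantic(float))             # dense float column
print(infer_semantic(str))               # dense string column
print(infer_semantic(str, ragged=True))  # ragged string column, e.g. tokenized text
print(infer_semantic(int))               # integers are numerical, whatever they encode
```

The last call shows that the rule only looks at the type: an integer column is treated as numerical regardless of what the integers stand for.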
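In some cases the inferred semantic is wrong. For example, an enum stored as an integer is semantically categorical, but it is detected as numerical. In that case, specify the semantic argument in the input. The education_num field of the Adult dataset is a classic example.
The dataset used here does not contain such a feature. However, for the sake of demonstration, we will make the model treat year as a categorical feature: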
# Set the cell height to 300
%set_cell_height 300
# Create a feature usage for "year" and force the categorical semantic
feature_1 = tfdf.keras.FeatureUsage(name="year", semantic=tfdf.keras.FeatureSemantic.CATEGORICAL)
# Create a feature usage for the numerical feature "bill_length_mm"
feature_2 = tfdf.keras.FeatureUsage(name="bill_length_mm")
# Create a feature usage for the categorical feature "sex"
feature_3 = tfdf.keras.FeatureUsage(name="sex")
# Put all feature usages in a list
all_features = [feature_1, feature_2, feature_3]
# Create a gradient boosted trees model that uses all_features and excludes every non-specified feature
model_3 = tfdf.keras.GradientBoostedTreesModel(features=all_features, exclude_non_specified_features=True)
# Compile the model with accuracy as the evaluation metric
model_3.compile(metrics=["accuracy"])
# Train on the training dataset, using the test dataset for validation
model_3.fit(train_ds, validation_data=test_ds)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpg9srb1ip as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.143701. Found 239 examples.
Reading validation dataset...
[WARNING 23-08-16 11:05:26.3095 UTC gradient_boosted_trees.cc:1818] "goss_alpha" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:26.3095 UTC gradient_boosted_trees.cc:1829] "goss_beta" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:26.3095 UTC gradient_boosted_trees.cc:1843] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
Num validation examples: tf.Tensor(105, shape=(), dtype=int32)
Validation dataset read in 0:00:00.152938. Found 105 examples.
Training model...
Model trained in 0:00:00.267350
Compiling model...
Model compiled.
[INFO 23-08-16 11:05:26.8771 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpg9srb1ip/model/ with prefix caa2217206e3449f
[INFO 23-08-16 11:05:26.8819 UTC decision_forest.cc:660] Model loaded with 42 root(s), 1322 node(s), and 3 input feature(s).
[INFO 23-08-16 11:05:26.8819 UTC kernel.cc:1075] Use fast generic engine
<keras.src.callbacks.History at 0x7f823404dc70>
Note that year is now in the list of CATEGORICAL features (unlike in the first run).
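13. Hyper-parameters
Hyper-parameters are parameters of the training algorithm that affect the quality of the final model. They are specified in the constructor of the model class. The list of hyper-parameters can be displayed with the question-mark Colab command (e.g. ?tfdf.keras.GradientBoostedTreesModel).
Alternatively, you can find them on the TensorFlow Decision Forests GitHub or in the Yggdrasil Decision Forests documentation.
The default hyper-parameters of each algorithm roughly match the original publication. To ensure consistency, new features and their matching hyper-parameters are always disabled by default. This is why tuning the hyper-parameters is a good idea.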
# Create a gradient boosted trees model with the BEST_FIRST_GLOBAL growing strategy, a maximum depth of 8 and num_trees=500
model_6 = tfdf.keras.GradientBoostedTreesModel(
    num_trees=500, growing_strategy="BEST_FIRST_GLOBAL", max_depth=8)
# Train the model on the training dataset
model_6.fit(train_ds)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmp3ys0wqar as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.173451. Found 239 examples.
Training model...
[WARNING 23-08-16 11:05:27.1239 UTC gradient_boosted_trees.cc:1818] "goss_alpha" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:27.1240 UTC gradient_boosted_trees.cc:1829] "goss_beta" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:27.1240 UTC gradient_boosted_trees.cc:1843] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
[INFO 23-08-16 11:05:33.8276 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmp3ys0wqar/model/ with prefix 10ca7d5d94bb4882
Model trained in 0:00:06.802956
Compiling model...
Model compiled.
[INFO 23-08-16 11:05:34.0959 UTC decision_forest.cc:660] Model loaded with 1500 root(s), 86196 node(s), and 7 input feature(s).
[INFO 23-08-16 11:05:34.0960 UTC abstract_model.cc:1311] Engine "GradientBoostedTreesGeneric" built
[INFO 23-08-16 11:05:34.0960 UTC kernel.cc:1075] Use fast generic engine
<keras.src.callbacks.History at 0x7f8234107490>
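# A more complex, but possibly more accurate, model.
# The GradientBoostedTreesModel is created with the following arguments:
# num_trees=500: maximum number of boosting iterations (the log reports 1500 roots for this 3-class problem, i.e. one tree per class per iteration)
# growing_strategy="BEST_FIRST_GLOBAL": best-first (leaf-wise) growth, i.e. the node with the best split among all nodes of the tree is expanded next
# max_depth=8: maximum depth of each tree is 8
# split_axis="SPARSE_OBLIQUE": sparse oblique splits, i.e. conditions that test a combination of several numerical features instead of a single feature
# categorical_algorithm="RANDOM": use the random search algorithm to find splits on categorical features
model_7 = tfdf.keras.GradientBoostedTreesModel(
    num_trees=500,
    growing_strategy="BEST_FIRST_GLOBAL",
    max_depth=8,
    split_axis="SPARSE_OBLIQUE",
    categorical_algorithm="RANDOM",
    )
# Train the model on the training dataset train_ds
model_7.fit(train_ds)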
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpkpibv70a as temporary training directory
Reading training dataset...
[WARNING 23-08-16 11:05:34.2860 UTC gradient_boosted_trees.cc:1818] "goss_alpha" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:34.2860 UTC gradient_boosted_trees.cc:1829] "goss_beta" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:34.2860 UTC gradient_boosted_trees.cc:1843] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
WARNING:tensorflow:5 out of the last 5 calls to <function CoreModel._consumes_training_examples_until_eof at 0x7f823cda3310> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
Training dataset read in 0:00:00.171961. Found 239 examples.
Training model...
[INFO 23-08-16 11:05:42.9485 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpkpibv70a/model/ with prefix c697f69f5a7e4d74
Model trained in 0:00:08.763151
Compiling model...
WARNING:tensorflow:5 out of the last 5 calls to <function InferenceCoreModel.make_predict_function.<locals>.predict_function_trained at 0x7f823050ec10> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
[INFO 23-08-16 11:05:43.2128 UTC decision_forest.cc:660] Model loaded with 1500 root(s), 85322 node(s), and 7 input feature(s).
[INFO 23-08-16 11:05:43.2128 UTC kernel.cc:1075] Use fast generic engine
Model compiled.
<keras.src.callbacks.History at 0x7f8230484760>
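As new training methods are published and implemented, combinations of hyper-parameters can emerge that are good, or almost always better, than the default parameters. To avoid changing the default hyper-parameter values, these good combinations are indexed and made available as hyper-parameter templates.
For example, the benchmark_rank1 template is the best combination on our internal benchmarks. Templates are versioned so that training configurations stay stable, e.g. benchmark_rank1@v1.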
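The `name@vN` convention can be illustrated with a short stand-alone sketch. Both helpers below are hypothetical and are not TF-DF functions; the rule that an unversioned name resolves to the highest available version is an assumption made for this illustration. The template contents are copied from the `predefined_hyperparameters()` output shown later in this section.

```python
def parse_template_name(spec):
    """Split 'name' or 'name@vN' into (name, version or None). Illustrative helper."""
    name, sep, version = spec.partition("@v")
    if not name:
        raise ValueError(f"Invalid template name: {spec!r}")
    if not sep:
        return name, None
    if not version.isdigit():
        raise ValueError(f"Invalid template version in: {spec!r}")
    return name, int(version)


def resolve_template(spec, templates):
    """Return (fully versioned name, parameters). Assumes an unversioned name means the latest version."""
    name, version = parse_template_name(spec)
    versions = templates[name]
    if version is None:
        version = max(versions)
    return f"{name}@v{version}", versions[version]


# Stand-in registry: template name -> {version: parameters}.
TEMPLATES = {
    "better_default": {1: {"growing_strategy": "BEST_FIRST_GLOBAL"}},
    "benchmark_rank1": {1: {
        "growing_strategy": "BEST_FIRST_GLOBAL",
        "categorical_algorithm": "RANDOM",
        "split_axis": "SPARSE_OBLIQUE",
        "sparse_oblique_normalization": "MIN_MAX",
        "sparse_oblique_num_projections_exponent": 1.0,
    }},
}

resolved_name, params = resolve_template("benchmark_rank1", TEMPLATES)
print(resolved_name, params)
```

Pinning the version explicitly (`benchmark_rank1@v1`) keeps a training configuration reproducible even if a newer version of the template is published later.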
# Import the required library
import tensorflow_decision_forests as tfdf
# Create a gradient boosted trees model from a predefined hyper-parameter template
# The "benchmark_rank1" template is a good choice
model_8 = tfdf.keras.GradientBoostedTreesModel(hyperparameter_template="benchmark_rank1")
# Train the model on the training dataset
model_8.fit(train_ds)
Resolve hyper-parameter template "benchmark_rank1" to "benchmark_rank1@v1" -> {'growing_strategy': 'BEST_FIRST_GLOBAL', 'categorical_algorithm': 'RANDOM', 'split_axis': 'SPARSE_OBLIQUE', 'sparse_oblique_normalization': 'MIN_MAX', 'sparse_oblique_num_projections_exponent': 1.0}.
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpzjvgcmpm as temporary training directory
Reading training dataset...
WARNING:tensorflow:6 out of the last 6 calls to <function CoreModel._consumes_training_examples_until_eof at 0x7f823cda3310> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
[WARNING 23-08-16 11:05:43.4453 UTC gradient_boosted_trees.cc:1818] "goss_alpha" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:43.4453 UTC gradient_boosted_trees.cc:1829] "goss_beta" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:43.4453 UTC gradient_boosted_trees.cc:1843] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
Training dataset read in 0:00:00.169369. Found 239 examples.
Training model...
Model trained in 0:00:03.481820
Compiling model...
[INFO 23-08-16 11:05:46.9935 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpzjvgcmpm/model/ with prefix 5763241a118b4af3
[INFO 23-08-16 11:05:47.0978 UTC decision_forest.cc:660] Model loaded with 900 root(s), 35042 node(s), and 7 input feature(s).
[INFO 23-08-16 11:05:47.0978 UTC abstract_model.cc:1311] Engine "GradientBoostedTreesGeneric" built
[INFO 23-08-16 11:05:47.0978 UTC kernel.cc:1075] Use fast generic engine
WARNING:tensorflow:6 out of the last 6 calls to <function InferenceCoreModel.make_predict_function.<locals>.predict_function_trained at 0x7f82304539d0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
Model compiled.
<keras.src.callbacks.History at 0x7f82303c9430>
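The available templates can be listed with predefined_hyperparameters. Note that different learning algorithms have different templates, even when the names are similar.
# Import the tfdf library
import tensorflow_decision_forests as tfdf
# Print the predefined hyper-parameter templates of the gradient boosted trees model
print(tfdf.keras.GradientBoostedTreesModel.predefined_hyperparameters())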
[HyperParameterTemplate(name='better_default', version=1, parameters={'growing_strategy': 'BEST_FIRST_GLOBAL'}, description='A configuration that is generally better than the default parameters without being more expensive.'), HyperParameterTemplate(name='benchmark_rank1', version=1, parameters={'growing_strategy': 'BEST_FIRST_GLOBAL', 'categorical_algorithm': 'RANDOM', 'split_axis': 'SPARSE_OBLIQUE', 'sparse_oblique_normalization': 'MIN_MAX', 'sparse_oblique_num_projections_exponent': 1.0}, description='Top ranking hyper-parameters on our benchmark slightly modified to run in reasonable time.')]
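14. Feature preprocessing
Preprocessing features is sometimes necessary to consume signals with complex structures, to regularize the model, or to apply transfer learning. Preprocessing can be done in one of three ways:
- Preprocessing on the Pandas dataframe. This solution is easy to implement and generally suitable for experimentation. However, the preprocessing logic is not exported with the model by model.save().
- Keras preprocessing: more complex than the previous solution, but the Keras preprocessing is packaged inside the model.
- TensorFlow feature columns: this API is part of the TF Estimator library (!= Keras) and is planned for deprecation. This solution is useful when you already have preprocessing code written with it.
Note: Using a pre-trained embedding from TensorFlow Hub is often a good way to consume text and images with TF-DF, for example hub.KerasLayer("https://tfhub.dev/google/nnlm-en-dim128/2"). See the intermediate tutorial for more details.
In the example for this section, the body_mass_g feature is preprocessed into body_mass_kg = body_mass_g / 1000, while bill_length_mm is consumed without preprocessing. Note that such a monotonic transformation generally has no effect on decision forest models.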
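The remark about monotonic transformations can be checked with a small stand-alone sketch. It is not TF-DF code: it searches for the best single threshold split by weighted Gini impurity on a toy feature, once in grams and once in kilograms, and compares which examples fall on the left side. The function names and the six values and labels are made up for illustration.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_split_left_indices(values, labels):
    """Return the set of example indices sent left by the best threshold split."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    best_score, best_left = None, None
    for cut in range(1, len(order)):
        # A threshold can only be placed between two distinct values.
        if values[order[cut - 1]] == values[order[cut]]:
            continue
        left, right = order[:cut], order[cut:]
        score = (len(left) * gini([labels[i] for i in left])
                 + len(right) * gini([labels[i] for i in right])) / len(order)
        if best_score is None or score < best_score:
            best_score, best_left = score, frozenset(left)
    return best_left


# Toy values (made up for this sketch, not taken from the dataset).
mass_g = [3750.0, 3800.0, 3250.0, 5700.0, 5400.0, 4675.0]
labels = ["A", "A", "A", "B", "B", "C"]
mass_kg = [g / 1000.0 for g in mass_g]

left_g = best_split_left_indices(mass_g, labels)
left_kg = best_split_left_indices(mass_kg, labels)
print(left_g == left_kg)
```

Dividing by 1000 changes the threshold value but not the ordering of the examples, so the split separates the same examples in both cases. The preprocessing in the tutorial code that follows is therefore a demonstration of the mechanism rather than a way to improve accuracy.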
# Define an input layer of shape (1,) named "body_mass_g"
body_mass_g = tf.keras.layers.Input(shape=(1,), name="body_mass_g")
# Divide "body_mass_g" by 1000 to obtain "body_mass_kg"
body_mass_kg = body_mass_g / 1000.0
# Define an input layer of shape (1,) named "bill_length_mm"
bill_length_mm = tf.keras.layers.Input(shape=(1,), name="bill_length_mm")
# Group the raw inputs in a dictionary keyed by input name
raw_inputs = {"body_mass_g": body_mass_g, "bill_length_mm": bill_length_mm}
# Group the processed features in a dictionary keyed by the processed feature name
processed_inputs = {"body_mass_kg": body_mass_kg, "bill_length_mm": bill_length_mm}
# Create a model holding the preprocessing logic: raw_inputs in, processed_inputs out
preprocessor = tf.keras.Model(inputs=raw_inputs, outputs=processed_inputs)
# Create a model that combines the preprocessing logic and the decision forest
model_4 = tfdf.keras.RandomForestModel(preprocessing=preprocessor)
# Train the model on the training dataset
model_4.fit(train_ds)
# Print the model summary
model_4.summary()
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpmthx2t9p as temporary training directory
Reading training dataset...
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/functional.py:639: UserWarning: Input dict contained keys ['island', 'bill_depth_mm', 'flipper_length_mm', 'sex', 'year'] which did not match any model input. They will be ignored by the model.
inputs = self._flatten_to_reference_inputs(inputs)
Training dataset read in 0:00:00.226996. Found 239 examples.
Training model...
Model trained in 0:00:00.072877
Compiling model...
Model compiled.
WARNING:tensorflow:5 out of the last 12 calls to <function InferenceCoreModel.yggdrasil_model_path_tensor at 0x7f82302b1f70> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
[INFO 23-08-16 11:05:47.5902 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpmthx2t9p/model/ with prefix 124e9d64a10e49f0
[INFO 23-08-16 11:05:47.6086 UTC decision_forest.cc:660] Model loaded with 300 root(s), 6310 node(s), and 2 input feature(s).
[INFO 23-08-16 11:05:47.6086 UTC kernel.cc:1075] Use fast generic engine
Model: "random_forest_model_1"
_________________________________________________________________
Layer (type)          Output Shape                                              Param #
=================================================================
model (Functional)    {'body_mass_kg': (None, 1), 'bill_length_mm': (None, 1)}  0
=================================================================
Total params: 1 (1.00 Byte)
Trainable params: 0 (0.00 Byte)
Non-trainable params: 1 (1.00 Byte)
_________________________________________________________________
Type: "RANDOM_FOREST"
Task: CLASSIFICATION
Label: "__LABEL"
Input Features (2):
bill_length_mm
body_mass_kg
No weights
Variable Importance: INV_MEAN_MIN_DEPTH:
1. "bill_length_mm" 0.866916 ################
2. "body_mass_kg" 0.488533
Variable Importance: NUM_AS_ROOT:
1. "bill_length_mm" 263.000000 ################
2. "body_mass_kg" 37.000000
Variable Importance: NUM_NODES:
1. "bill_length_mm" 1537.000000 ################
2. "body_mass_kg" 1468.000000
Variable Importance: SUM_SCORE:
1. "bill_length_mm" 41227.008434 ################
2. "body_mass_kg" 27680.406197
Winner takes all: true
Out-of-bag evaluation: accuracy:0.920502 logloss:0.636824
Number of trees: 300
Total number of nodes: 6310
Number of nodes by tree:
Count: 300 Average: 21.0333 StdDev: 3.13882
Min: 11 Max: 29 Ignored: 0
----------------------------------------------
[ 11, 12) 1 0.33% 0.33%
[ 12, 13) 0 0.00% 0.33%
[ 13, 14) 1 0.33% 0.67%
[ 14, 15) 0 0.00% 0.67%
[ 15, 16) 15 5.00% 5.67% ##
[ 16, 17) 0 0.00% 5.67%
[ 17, 18) 35 11.67% 17.33% ####
[ 18, 19) 0 0.00% 17.33%
[ 19, 20) 56 18.67% 36.00% #######
[ 20, 21) 0 0.00% 36.00%
[ 21, 22) 79 26.33% 62.33% ##########
[ 22, 23) 0 0.00% 62.33%
[ 23, 24) 58 19.33% 81.67% #######
[ 24, 25) 0 0.00% 81.67%
[ 25, 26) 40 13.33% 95.00% #####
[ 26, 27) 0 0.00% 95.00%
[ 27, 28) 13 4.33% 99.33% ##
[ 28, 29) 0 0.00% 99.33%
[ 29, 29] 2 0.67% 100.00%
Depth by leafs:
Count: 3305 Average: 4.01755 StdDev: 1.39146
Min: 1 Max: 8 Ignored: 0
----------------------------------------------
[ 1, 2) 20 0.61% 0.61%
[ 2, 3) 368 11.13% 11.74% ####
[ 3, 4) 918 27.78% 39.52% #########
[ 4, 5) 973 29.44% 68.96% ##########
[ 5, 6) 517 15.64% 84.60% #####
[ 6, 7) 318 9.62% 94.22% ###
[ 7, 8) 145 4.39% 98.61% #
[ 8, 8] 46 1.39% 100.00%
Number of training obs by leaf:
Count: 3305 Average: 21.6944 StdDev: 26.3178
Min: 5 Max: 107 Ignored: 0
----------------------------------------------
[ 5, 10) 2102 63.60% 63.60% ##########
[ 10, 15) 237 7.17% 70.77% #
[ 15, 20) 54 1.63% 72.41%
[ 20, 25) 17 0.51% 72.92%
[ 25, 30) 53 1.60% 74.52%
[ 30, 35) 74 2.24% 76.76%
[ 35, 41) 99 3.00% 79.76%
[ 41, 46) 58 1.75% 81.51%
[ 46, 51) 23 0.70% 82.21%
[ 51, 56) 18 0.54% 82.75%
[ 56, 61) 58 1.75% 84.51%
[ 61, 66) 70 2.12% 86.63%
[ 66, 71) 102 3.09% 89.71%
[ 71, 77) 109 3.30% 93.01% #
[ 77, 82) 76 2.30% 95.31%
[ 82, 87) 70 2.12% 97.43%
[ 87, 92) 40 1.21% 98.64%
[ 92, 97) 23 0.70% 99.33%
[ 97, 102) 16 0.48% 99.82%
[ 102, 107] 6 0.18% 100.00%
Attribute in nodes:
1537 : bill_length_mm [NUMERICAL]
1468 : body_mass_kg [NUMERICAL]
Attribute in nodes with depth <= 0:
263 : bill_length_mm [NUMERICAL]
37 : body_mass_kg [NUMERICAL]
Attribute in nodes with depth <= 1:
446 : bill_length_mm [NUMERICAL]
434 : body_mass_kg [NUMERICAL]
Attribute in nodes with depth <= 2:
917 : body_mass_kg [NUMERICAL]
755 : bill_length_mm [NUMERICAL]
Attribute in nodes with depth <= 3:
1195 : body_mass_kg [NUMERICAL]
1143 : bill_length_mm [NUMERICAL]
Attribute in nodes with depth <= 5:
1477 : bill_length_mm [NUMERICAL]
1421 : body_mass_kg [NUMERICAL]
Condition type in nodes:
3005 : HigherCondition
Condition type in nodes with depth <= 0:
300 : HigherCondition
Condition type in nodes with depth <= 1:
880 : HigherCondition
Condition type in nodes with depth <= 2:
1672 : HigherCondition
Condition type in nodes with depth <= 3:
2338 : HigherCondition
Condition type in nodes with depth <= 5:
2898 : HigherCondition
Node format: NOT_SET
Training OOB:
trees: 1, Out-of-bag evaluation: accuracy:0.875 logloss:4.50546
trees: 13, Out-of-bag evaluation: accuracy:0.890295 logloss:2.35926
trees: 23, Out-of-bag evaluation: accuracy:0.891213 logloss:1.76382
trees: 35, Out-of-bag evaluation: accuracy:0.903766 logloss:1.61533
trees: 46, Out-of-bag evaluation: accuracy:0.912134 logloss:1.61544
trees: 59, Out-of-bag evaluation: accuracy:0.912134 logloss:1.33186
trees: 69, Out-of-bag evaluation: accuracy:0.916318 logloss:1.19735
trees: 80, Out-of-bag evaluation: accuracy:0.920502 logloss:1.20323
trees: 90, Out-of-bag evaluation: accuracy:0.916318 logloss:1.06613
trees: 102, Out-of-bag evaluation: accuracy:0.916318 logloss:0.920117
trees: 112, Out-of-bag evaluation: accuracy:0.916318 logloss:0.919398
trees: 122, Out-of-bag evaluation: accuracy:0.916318 logloss:0.918544
trees: 132, Out-of-bag evaluation: accuracy:0.916318 logloss:0.917733
trees: 143, Out-of-bag evaluation: accuracy:0.920502 logloss:0.916464
trees: 154, Out-of-bag evaluation: accuracy:0.920502 logloss:0.916065
trees: 167, Out-of-bag evaluation: accuracy:0.920502 logloss:0.915384
trees: 178, Out-of-bag evaluation: accuracy:0.920502 logloss:0.781669
trees: 188, Out-of-bag evaluation: accuracy:0.924686 logloss:0.782319
trees: 200, Out-of-bag evaluation: accuracy:0.920502 logloss:0.774758
trees: 210, Out-of-bag evaluation: accuracy:0.924686 logloss:0.774025
trees: 221, Out-of-bag evaluation: accuracy:0.924686 logloss:0.770531
trees: 231, Out-of-bag evaluation: accuracy:0.924686 logloss:0.77066
trees: 241, Out-of-bag evaluation: accuracy:0.924686 logloss:0.767545
trees: 251, Out-of-bag evaluation: accuracy:0.924686 logloss:0.767962
trees: 261, Out-of-bag evaluation: accuracy:0.924686 logloss:0.767063
trees: 271, Out-of-bag evaluation: accuracy:0.924686 logloss:0.767585
trees: 281, Out-of-bag evaluation: accuracy:0.924686 logloss:0.766893
trees: 292, Out-of-bag evaluation: accuracy:0.924686 logloss:0.634927
trees: 300, Out-of-bag evaluation: accuracy:0.920502 logloss:0.636824
The following example re-implements the same logic using TensorFlow feature columns.
# Define g_to_kg, which converts grams to kilograms.
def g_to_kg(x):
    return x / 1000
# Define the feature columns for "body_mass_g" and "bill_length_mm".
feature_columns = [
    # Apply g_to_kg to "body_mass_g" as its normalizer function.
    tf.feature_column.numeric_column("body_mass_g", normalizer_fn=g_to_kg),
    tf.feature_column.numeric_column("bill_length_mm"),
]
# Create a preprocessing layer that applies the feature columns to the input data.
preprocessing = tf.keras.layers.DenseFeatures(feature_columns)
# Create a random forest model that uses this layer for preprocessing.
model_5 = tfdf.keras.RandomForestModel(preprocessing=preprocessing)
# Train the model on the training dataset.
model_5.fit(train_ds)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10582/2850711544.py:5: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpiqyvbd4a as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.163336. Found 239 examples.
Training model...
Model trained in 0:00:00.050388
Compiling model...
Model compiled.
WARNING:tensorflow:6 out of the last 13 calls to <function InferenceCoreModel.yggdrasil_model_path_tensor at 0x7f82301805e0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
[INFO 23-08-16 11:05:47.9585 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpiqyvbd4a/model/ with prefix b3488dac1bd3468f
[INFO 23-08-16 11:05:47.9775 UTC decision_forest.cc:660] Model loaded with 300 root(s), 6310 node(s), and 2 input feature(s).
[INFO 23-08-16 11:05:47.9776 UTC kernel.cc:1075] Use fast generic engine
<keras.src.callbacks.History at 0x7f82301a90a0>
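15. Training a regression model
The previous examples trained classification models (TF-DF does not differentiate between binary and multi-class classification). In the next example, we train a regression model on the Abalone dataset. The objective of this dataset is to predict the number of rings on an abalone's shell.
Note: The CSV file was assembled by appending UCI's header and data files. No preprocessing was applied.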
# Download the dataset.
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/abalone_raw.csv -O /tmp/abalone.csv
# Read the CSV file into a DataFrame named dataset_df.
dataset_df = pd.read_csv("/tmp/abalone.csv")
# Print the first 3 rows of the DataFrame.
print(dataset_df.head(3))
  Type  LongestShell  Diameter  Height  WholeWeight  ShuckedWeight  \
0    M         0.455     0.365   0.095       0.5140         0.2245
1    M         0.350     0.265   0.090       0.2255         0.0995
2    F         0.530     0.420   0.135       0.6770         0.2565
   VisceraWeight  ShellWeight  Rings
0         0.1010         0.15     15
1         0.0485         0.07      7
2         0.1415         0.21      9
# Split the dataset into a training set and a test set.
train_ds_pd, test_ds_pd = split_dataset(dataset_df)
# Print the number of examples in each split.
print("{} examples in training, {} examples for testing.".format(len(train_ds_pd), len(test_ds_pd)))
# Name of the label column.
label = "Rings"
# Convert the Pandas DataFrames into TensorFlow datasets for a regression task.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label=label, task=tfdf.keras.Task.REGRESSION)
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label=label, task=tfdf.keras.Task.REGRESSION)
2885 examples in training, 1292 examples for testing.
# Configure the model: a random forest for a regression task.
model_7 = tfdf.keras.RandomForestModel(task=tfdf.keras.Task.REGRESSION)
# Train the model on the training dataset train_ds.
model_7.fit(train_ds)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpr6p677s9 as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.210193. Found 2885 examples.
Training model...
[INFO 23-08-16 11:05:49.2476 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpr6p677s9/model/ with prefix 8e87bf0ca0c24b13
Model trained in 0:00:01.408931
Compiling model...
[INFO 23-08-16 11:05:50.0462 UTC decision_forest.cc:660] Model loaded with 300 root(s), 260570 node(s), and 8 input feature(s).
[INFO 23-08-16 11:05:50.0463 UTC kernel.cc:1075] Use fast generic engine
Model compiled.
<keras.src.callbacks.History at 0x7f823b7d6100>
# Evaluate the model on the test dataset.
model_7.compile(metrics=["mse"])  # Compile the model with mean squared error (MSE) as the metric.
evaluation = model_7.evaluate(test_ds, return_dict=True)  # Evaluate on the test dataset and return the results as a dictionary.
print(evaluation)  # Print the evaluation dictionary.
print()
print(f"MSE: {evaluation['mse']}")  # Print the mean squared error.
print(f"RMSE: {math.sqrt(evaluation['mse'])}")  # Print the root mean squared error.
WARNING:tensorflow:5 out of the last 5 calls to <function InferenceCoreModel.make_test_function.<locals>.test_function at 0x7f82300f5310> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
1/2 [==============>...............] - ETA: 0s - loss: 0.0000e+00 - mse: 4.7546
2/2 [==============================] - 0s 15ms/step - loss: 0.0000e+00 - mse: 4.6777
{'loss': 0.0, 'mse': 4.677670955657959}
MSE: 4.677670955657959
RMSE: 2.1627923977252093
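To make the relationship between the two printed numbers concrete, here is a minimal sketch that computes MSE and RMSE by hand in plain Python. The helper `mean_squared_error` is written for this illustration, and the ring counts and predictions are made-up values, not outputs of model_7.

```python
import math


def mean_squared_error(y_true, y_pred):
    """Average of the squared differences between targets and predictions."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical ring counts and predictions (illustrative only).
rings_true = [15, 7, 9, 10]
rings_pred = [12.0, 8.0, 10.0, 10.0]

mse = mean_squared_error(rings_true, rings_pred)
rmse = math.sqrt(mse)  # RMSE is the square root of MSE.

print(f"MSE: {mse}")
print(f"RMSE: {rmse}")
```

Because RMSE is expressed in the same unit as the label, the value of about 2.16 reported above can be read as a prediction error on the order of two rings, whereas the MSE of about 4.68 is in squared rings.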