Author: Khalid Salama
Date created: 2022/01/25
Last modified: 2022/01/25
Description: Classifying structured data with TensorFlow Decision Forests.
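TensorFlow Decision Forests is a collection of state-of-the-art algorithms for decision forest models that are compatible with the Keras API. The models include Random Forests, Gradient Boosted Trees, and CART, and can be used for regression, classification, and ranking tasks. For a beginner's guide to TensorFlow Decision Forests, please refer to this tutorial.
This example uses a Gradient Boosted Trees model for binary classification of structured data, and covers the following scenarios:
This example uses TensorFlow 2.7 or higher, as well as TensorFlow Decision Forests, which you can install using the following command: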
pip install -U tensorflow_decision_forests
import math
import urllib.request
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_decision_forests as tfdf
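This example uses the United States Census Income Dataset provided by the UC Irvine Machine Learning Repository. The task is binary classification: determine whether a person makes over $50K a year.
The dataset includes ~300K instances with 41 input features: 7 numerical features and 34 categorical features.
First, we load the data from the UCI Machine Learning Repository into a Pandas DataFrame.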
BASE_PATH = "https://kdd.ics.uci.edu/databases/census-income/census-income"
CSV_HEADER = [
    l.decode("utf-8").split(":")[0].replace(" ", "_")
    for l in urllib.request.urlopen(f"{BASE_PATH}.names")
    if not l.startswith(b"|")
][2:]
CSV_HEADER.append("income_level")
train_data = pd.read_csv(f"{BASE_PATH}.data.gz", header=None, names=CSV_HEADER,)
test_data = pd.read_csv(f"{BASE_PATH}.test.gz", header=None, names=CSV_HEADER,)
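The header-parsing comprehension above keeps every line of the .names file that does not start with "|", takes the text before the first ":", replaces spaces with underscores, and drops the first two surviving entries. The sketch below replays that logic on a few hypothetical byte lines instead of downloading the real file; the sample content and the helper name parse_header are made up for illustration.

```python
def parse_header(lines):
    """Replays the CSV_HEADER comprehension on an iterable of byte lines."""
    return [
        line.decode("utf-8").split(":")[0].replace(" ", "_")
        for line in lines
        if not line.startswith(b"|")  # skip comment lines
    ][2:]  # drop the first two non-comment lines


# Hypothetical lines in the style of the `.names` file (not its real content).
sample_lines = [
    b"| A comment line that is skipped.\n",
    b"first non-comment line\n",
    b"second non-comment line\n",
    b"age: continuous.\n",
    b"class of worker: Not in universe, Private.\n",
]

header = parse_header(sample_lines)
header.append("income_level")
print(header)  # ['age', 'class_of_worker', 'income_level']
```

Because the column names come from the text before ":", any entry without a colon would keep its trailing newline; the [2:] slice is what keeps such lines out of the header.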
Here, we define the metadata of the dataset, which is useful for encoding the input features according to their types.
# Target column name.
TARGET_COLUMN_NAME = "income_level"
# The labels of the target columns.
TARGET_LABELS = [" - 50000.", " 50000+."]
# Weight column name.
WEIGHT_COLUMN_NAME = "instance_weight"
# Numeric feature names.
NUMERIC_FEATURE_NAMES = [
"age",
"wage_per_hour",
"capital_gains",
"capital_losses",
"dividends_from_stocks",
"num_persons_worked_for_employer",
"weeks_worked_in_year",
]
# Categorical feature names.
CATEGORICAL_FEATURE_NAMES = [
"class_of_worker",
"detailed_industry_recode",
"detailed_occupation_recode",
"education",
"enroll_in_edu_inst_last_wk",
"marital_stat",
"major_industry_code",
"major_occupation_code",
"race",
"hispanic_origin",
"sex",
"member_of_a_labor_union",
"reason_for_unemployment",
"full_or_part_time_employment_stat",
"tax_filer_stat",
"region_of_previous_residence",
"state_of_previous_residence",
"detailed_household_and_family_stat",
"detailed_household_summary_in_household",
"migration_code-change_in_msa",
"migration_code-change_in_reg",
"migration_code-move_within_reg",
"live_in_this_house_1_year_ago",
"migration_prev_res_in_sunbelt",
"family_members_under_18",
"country_of_birth_father",
"country_of_birth_mother",
"country_of_birth_self",
"citizenship",
"own_business_or_self_employed",
"fill_inc_questionnaire_for_veteran's_admin",
"veterans_benefits",
"year",
]
Now we perform basic data preparation.
def prepare_dataframe(dataframe):
    # Convert the target labels from string to integer.
    dataframe[TARGET_COLUMN_NAME] = dataframe[TARGET_COLUMN_NAME].map(
        TARGET_LABELS.index
    )
    # Cast the categorical features to string.
    for feature_name in CATEGORICAL_FEATURE_NAMES:
        dataframe[feature_name] = dataframe[feature_name].astype(str)


prepare_dataframe(train_data)
prepare_dataframe(test_data)
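To see what prepare_dataframe() does without downloading the dataset, the sketch below applies the same function to a tiny hand-built DataFrame. The rows are made up; only the label strings and column names come from the metadata above, and the categorical list is shortened to two columns for illustration.

```python
import pandas as pd

TARGET_COLUMN_NAME = "income_level"
TARGET_LABELS = [" - 50000.", " 50000+."]
# A small subset of the categorical features, for illustration only.
CATEGORICAL_FEATURE_NAMES = ["detailed_industry_recode", "year"]


def prepare_dataframe(dataframe):
    # Map each label string to its position in TARGET_LABELS (0 or 1).
    dataframe[TARGET_COLUMN_NAME] = dataframe[TARGET_COLUMN_NAME].map(
        TARGET_LABELS.index
    )
    # Integer-coded categories become strings so they are treated as categorical.
    for feature_name in CATEGORICAL_FEATURE_NAMES:
        dataframe[feature_name] = dataframe[feature_name].astype(str)


# Toy rows with made-up values but the dataset's exact label strings.
toy = pd.DataFrame(
    {
        "age": [73, 58],
        "detailed_industry_recode": [0, 4],
        "year": [95, 94],
        "income_level": [" - 50000.", " 50000+."],
    }
)
prepare_dataframe(toy)
print(toy["income_level"].tolist())  # [0, 1]
print(toy["detailed_industry_recode"].tolist())  # ['0', '4']
```

Using TARGET_LABELS.index as the mapping function is strict: a label string that does not match exactly (for example, different leading whitespace) raises a ValueError instead of silently producing a missing value.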
Now let's display the shapes of the training and test dataframes, and show a few instances.
print(f"Train data shape: {train_data.shape}")
print(f"Test data shape: {test_data.shape}")
print(train_data.head().T)
Train data shape: (199523, 42)
Test data shape: (99762, 42)
0 \
age 73
class_of_worker Not in universe
detailed_industry_recode 0
detailed_occupation_recode 0
education High school graduate
wage_per_hour 0
enroll_in_edu_inst_last_wk Not in universe
marital_stat Widowed
major_industry_code Not in universe or children
major_occupation_code Not in universe
race White
hispanic_origin All other
sex Female
member_of_a_labor_union Not in universe
reason_for_unemployment Not in universe
full_or_part_time_employment_stat Not in labor force
capital_gains 0
capital_losses 0
dividends_from_stocks 0
tax_filer_stat Nonfiler
region_of_previous_residence Not in universe
state_of_previous_residence Not in universe
detailed_household_and_family_stat Other Rel 18+ ever marr not in subfamily
detailed_household_summary_in_household Other relative of householder
instance_weight 1700.09
migration_code-change_in_msa ?
migration_code-change_in_reg ?
migration_code-move_within_reg ?
live_in_this_house_1_year_ago Not in universe under 1 year old
migration_prev_res_in_sunbelt ?
num_persons_worked_for_employer 0
family_members_under_18 Not in universe
country_of_birth_father United-States
country_of_birth_mother United-States
country_of_birth_self United-States
citizenship Native- Born in the United States
own_business_or_self_employed 0
fill_inc_questionnaire_for_veteran's_admin Not in universe
veterans_benefits 2
weeks_worked_in_year 0
year 95
income_level 0
1 \
age 58
class_of_worker Self-employed-not incorporated
detailed_industry_recode 4
detailed_occupation_recode 34
education Some college but no degree
wage_per_hour 0
enroll_in_edu_inst_last_wk Not in universe
marital_stat Divorced
major_industry_code Construction
major_occupation_code Precision production craft & repair
race White
hispanic_origin All other
sex Male
member_of_a_labor_union Not in universe
reason_for_unemployment Not in universe
full_or_part_time_employment_stat Children or Armed Forces
capital_gains 0
capital_losses 0
dividends_from_stocks 0
tax_filer_stat Head of household
region_of_previous_residence South
state_of_previous_residence Arkansas
detailed_household_and_family_stat Householder
detailed_household_summary_in_household Householder
instance_weight 1053.55
migration_code-change_in_msa MSA to MSA
migration_code-change_in_reg Same county
migration_code-move_within_reg Same county
live_in_this_house_1_year_ago No
migration_prev_res_in_sunbelt Yes
num_persons_worked_for_employer 1
family_members_under_18 Not in universe
country_of_birth_father United-States
country_of_birth_mother United-States
country_of_birth_self United-States
citizenship Native- Born in the United States
own_business_or_self_employed 0
fill_inc_questionnaire_for_veteran's_admin Not in universe
veterans_benefits 2
weeks_worked_in_year 52
year 94
income_level 0
2 \
age 18
class_of_worker Not in universe
detailed_industry_recode 0
detailed_occupation_recode 0
education 10th grade
wage_per_hour 0
enroll_in_edu_inst_last_wk High school
marital_stat Never married
major_industry_code Not in universe or children
major_occupation_code Not in universe
race Asian or Pacific Islander
hispanic_origin All other
sex Female
member_of_a_labor_union Not in universe
reason_for_unemployment Not in universe
full_or_part_time_employment_stat Not in labor force
capital_gains 0
capital_losses 0
dividends_from_stocks 0
tax_filer_stat Nonfiler
region_of_previous_residence Not in universe
state_of_previous_residence Not in universe
detailed_household_and_family_stat Child 18+ never marr Not in a subfamily
detailed_household_summary_in_household Child 18 or older
instance_weight 991.95
migration_code-change_in_msa ?
migration_code-change_in_reg ?
migration_code-move_within_reg ?
live_in_this_house_1_year_ago Not in universe under 1 year old
migration_prev_res_in_sunbelt ?
num_persons_worked_for_employer 0
family_members_under_18 Not in universe
country_of_birth_father Vietnam
country_of_birth_mother Vietnam
country_of_birth_self Vietnam
citizenship Foreign born- Not a citizen of U S
own_business_or_self_employed 0
fill_inc_questionnaire_for_veteran's_admin Not in universe
veterans_benefits 2
weeks_worked_in_year 0
year 95
income_level 0
3 \
age 9
class_of_worker Not in universe
detailed_industry_recode 0
detailed_occupation_recode 0
education Children
wage_per_hour 0
enroll_in_edu_inst_last_wk Not in universe
marital_stat Never married
major_industry_code Not in universe or children
major_occupation_code Not in universe
race White
hispanic_origin All other
sex Female
member_of_a_labor_union Not in universe
reason_for_unemployment Not in universe
full_or_part_time_employment_stat Children or Armed Forces
capital_gains 0
capital_losses 0
dividends_from_stocks 0
tax_filer_stat Nonfiler
region_of_previous_residence Not in universe
state_of_previous_residence Not in universe
detailed_household_and_family_stat Child <18 never marr not in subfamily
detailed_household_summary_in_household Child under 18 never married
instance_weight 1758.14
migration_code-change_in_msa Nonmover
migration_code-change_in_reg Nonmover
migration_code-move_within_reg Nonmover
live_in_this_house_1_year_ago Yes
migration_prev_res_in_sunbelt Not in universe
num_persons_worked_for_employer 0
family_members_under_18 Both parents present
country_of_birth_father United-States
country_of_birth_mother United-States
country_of_birth_self United-States
citizenship Native- Born in the United States
own_business_or_self_employed 0
fill_inc_questionnaire_for_veteran's_admin Not in universe
veterans_benefits 0
weeks_worked_in_year 0
year 94
income_level 0
4
age 10
class_of_worker Not in universe
detailed_industry_recode 0
detailed_occupation_recode 0
education Children
wage_per_hour 0
enroll_in_edu_inst_last_wk Not in universe
marital_stat Never married
major_industry_code Not in universe or children
major_occupation_code Not in universe
race White
hispanic_origin All other
sex Female
member_of_a_labor_union Not in universe
reason_for_unemployment Not in universe
full_or_part_time_employment_stat Children or Armed Forces
capital_gains 0
capital_losses 0
dividends_from_stocks 0
tax_filer_stat Nonfiler
region_of_previous_residence Not in universe
state_of_previous_residence Not in universe
detailed_household_and_family_stat Child <18 never marr not in subfamily
detailed_household_summary_in_household Child under 18 never married
instance_weight 1069.16
migration_code-change_in_msa Nonmover
migration_code-change_in_reg Nonmover
migration_code-move_within_reg Nonmover
live_in_this_house_1_year_ago Yes
migration_prev_res_in_sunbelt Not in universe
num_persons_worked_for_employer 0
family_members_under_18 Both parents present
country_of_birth_father United-States
country_of_birth_mother United-States
country_of_birth_self United-States
citizenship Native- Born in the United States
own_business_or_self_employed 0
fill_inc_questionnaire_for_veteran's_admin Not in universe
veterans_benefits 0
weeks_worked_in_year 0
year 94
income_level 0
You can find all the parameters of the Gradient Boosted Trees model in the documentation.
# Maximum number of decision trees. The effective number of trained trees can be smaller if early stopping is enabled.
NUM_TREES = 250
# Minimum number of examples in a node.
MIN_EXAMPLES = 6
# Maximum depth of the tree. max_depth=1 means that all trees will be roots.
MAX_DEPTH = 5
# Ratio of the dataset (sampling without replacement) used to train individual trees for the random sampling method.
SUBSAMPLE = 0.65
# Control the sampling of the datasets used to train individual trees.
SAMPLING_METHOD = "RANDOM"
# Ratio of the training dataset used to monitor the training. Must be > 0 if early stopping is enabled.
VALIDATION_RATIO = 0.1
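The run_experiment() method is responsible for loading the training and test datasets, training a given model, and evaluating the trained model.
Note that when training a Decision Forests model, only one epoch is needed to read the full dataset. Any additional epochs only slow down training unnecessarily. Therefore, the default num_epochs=1 is used in the run_experiment() method.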
def run_experiment(model, train_data, test_data, num_epochs=1, batch_size=None):
    train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
        train_data, label=TARGET_COLUMN_NAME, weight=WEIGHT_COLUMN_NAME
    )
    test_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
        test_data, label=TARGET_COLUMN_NAME, weight=WEIGHT_COLUMN_NAME
    )

    model.fit(train_dataset, epochs=num_epochs, batch_size=batch_size)
    _, accuracy = model.evaluate(test_dataset, verbose=0)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
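You can attach semantics to each feature to control how the model uses it. If not specified, the semantics are inferred from the representation type. It is recommended to specify the feature usages explicitly to avoid incorrectly inferred semantics. For example, a categorical value identifier (an integer) will be inferred as numerical, even though it is semantically categorical.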
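The failure mode described above can be illustrated with a deliberately naive, dtype-based rule. The function naive_infer_semantic below is hypothetical and is NOT how TF-DF infers semantics; it only shows why an integer-coded column such as detailed_industry_recode looks numerical unless you say otherwise.

```python
import pandas as pd


def naive_infer_semantic(series):
    """Toy rule (not TF-DF's implementation): numeric dtype -> NUMERICAL, else CATEGORICAL."""
    return "NUMERICAL" if pd.api.types.is_numeric_dtype(series) else "CATEGORICAL"


# Made-up integer codes standing in for an industry recode column.
codes = pd.Series([0, 4, 33, 4], name="detailed_industry_recode")

print(naive_infer_semantic(codes))  # NUMERICAL (wrong: these are category IDs)
print(naive_infer_semantic(codes.astype(str)))  # CATEGORICAL
```

A numerical semantic would let the model split on thresholds such as "code >= 17", which has no meaning for category IDs; declaring the semantic explicitly, as specify_feature_usages() does, avoids relying on any inference rule.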
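For numerical features, you can set the discretized parameter to the number of buckets into which the numerical feature should be discretized. This makes training faster but may lead to worse models.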
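TF-DF performs any discretization internally. The NumPy sketch below only illustrates the general idea of bucketing a numerical feature, using quantile boundaries as an assumed boundary rule (TF-DF's own rule may differ); the helper name discretize is hypothetical.

```python
import numpy as np


def discretize(values, num_buckets):
    """Map each value to a bucket index in [0, num_buckets) using quantile boundaries."""
    # The interior quantiles give num_buckets - 1 boundaries.
    boundaries = np.quantile(values, np.linspace(0, 1, num_buckets + 1)[1:-1])
    # np.digitize returns i such that boundaries[i-1] <= x < boundaries[i].
    return np.digitize(values, boundaries), boundaries


values = np.array([1, 2, 3, 4, 5, 6, 7, 8])
buckets, boundaries = discretize(values, num_buckets=4)
print(boundaries)  # [2.75 4.5  6.25]
print(buckets.tolist())  # [0, 0, 1, 1, 2, 2, 3, 3]
```

After bucketing, the learner only has to consider a handful of candidate split points per feature instead of every distinct value, which is where the speed-up comes from; the cost is that splits between two values in the same bucket are no longer possible. For a feature where most values are 0 (as capital_gains appears to be in the rows shown above), many quantile boundaries coincide, so fewer distinct buckets are actually produced.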
def specify_feature_usages():
    feature_usages = []

    for feature_name in NUMERIC_FEATURE_NAMES:
        feature_usage = tfdf.keras.FeatureUsage(
            name=feature_name, semantic=tfdf.keras.FeatureSemantic.NUMERICAL
        )
        feature_usages.append(feature_usage)

    for feature_name in CATEGORICAL_FEATURE_NAMES:
        feature_usage = tfdf.keras.FeatureUsage(
            name=feature_name, semantic=tfdf.keras.FeatureSemantic.CATEGORICAL
        )
        feature_usages.append(feature_usage)

    return feature_usages
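When compiling a decision forests model, you may only provide extra evaluation metrics. The loss is specified in the model construction, and the optimizer is irrelevant to decision forests models.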
def create_gbt_model():
    # See all the model parameters in https://tensorflowcn.cn/decision_forests/api_docs/python/tfdf/keras/GradientBoostedTreesModel
    gbt_model = tfdf.keras.GradientBoostedTreesModel(
        features=specify_feature_usages(),
        exclude_non_specified_features=True,
        num_trees=NUM_TREES,
        max_depth=MAX_DEPTH,
        min_examples=MIN_EXAMPLES,
        subsample=SUBSAMPLE,
        validation_ratio=VALIDATION_RATIO,
        task=tfdf.keras.Task.CLASSIFICATION,
    )

    gbt_model.compile(metrics=[keras.metrics.BinaryAccuracy(name="accuracy")])
    return gbt_model
gbt_model = create_gbt_model()
run_experiment(gbt_model, train_data, test_data)
Starting reading the dataset
200/200 [==============================] - ETA: 0s
Dataset read in 0:00:08.829036
Training model
Model trained in 0:00:48.639771
Compiling model
200/200 [==============================] - 58s 268ms/step
Test accuracy: 95.79%
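The model.summary() method displays several types of information about your decision trees model: the model type, task, input features, and feature importance.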
print(gbt_model.summary())
Model: "gradient_boosted_trees_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
=================================================================
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "GRADIENT_BOOSTED_TREES"
Task: CLASSIFICATION
Label: "__LABEL"
Input Features (40):
age
capital_gains
capital_losses
citizenship
class_of_worker
country_of_birth_father
country_of_birth_mother
country_of_birth_self
detailed_household_and_family_stat
detailed_household_summary_in_household
detailed_industry_recode
detailed_occupation_recode
dividends_from_stocks
education
enroll_in_edu_inst_last_wk
family_members_under_18
fill_inc_questionnaire_for_veteran's_admin
full_or_part_time_employment_stat
hispanic_origin
live_in_this_house_1_year_ago
major_industry_code
major_occupation_code
marital_stat
member_of_a_labor_union
migration_code-change_in_msa
migration_code-change_in_reg
migration_code-move_within_reg
migration_prev_res_in_sunbelt
num_persons_worked_for_employer
own_business_or_self_employed
race
reason_for_unemployment
region_of_previous_residence
sex
state_of_previous_residence
tax_filer_stat
veterans_benefits
wage_per_hour
weeks_worked_in_year
year
Trained with weights
Variable Importance: MEAN_MIN_DEPTH:
1. "enroll_in_edu_inst_last_wk" 3.942647 ################
2. "family_members_under_18" 3.942647 ################
3. "live_in_this_house_1_year_ago" 3.942647 ################
4. "migration_code-change_in_msa" 3.942647 ################
5. "migration_code-move_within_reg" 3.942647 ################
6. "year" 3.942647 ################
7. "__LABEL" 3.942647 ################
8. "__WEIGHTS" 3.942647 ################
9. "citizenship" 3.942137 ###############
10. "detailed_household_summary_in_household" 3.942137 ###############
11. "region_of_previous_residence" 3.942137 ###############
12. "veterans_benefits" 3.942137 ###############
13. "migration_prev_res_in_sunbelt" 3.940135 ###############
14. "migration_code-change_in_reg" 3.939926 ###############
15. "major_occupation_code" 3.937681 ###############
16. "major_industry_code" 3.933687 ###############
17. "reason_for_unemployment" 3.926320 ###############
18. "hispanic_origin" 3.900776 ###############
19. "member_of_a_labor_union" 3.894843 ###############
20. "race" 3.878617 ###############
21. "num_persons_worked_for_employer" 3.818566 ##############
22. "marital_stat" 3.795667 ##############
23. "full_or_part_time_employment_stat" 3.795431 ##############
24. "country_of_birth_mother" 3.787967 ##############
25. "tax_filer_stat" 3.784505 ##############
26. "fill_inc_questionnaire_for_veteran's_admin" 3.783607 ##############
27. "own_business_or_self_employed" 3.776398 ##############
28. "country_of_birth_father" 3.715252 #############
29. "sex" 3.708745 #############
30. "class_of_worker" 3.688424 #############
31. "weeks_worked_in_year" 3.665290 #############
32. "state_of_previous_residence" 3.657234 #############
33. "country_of_birth_self" 3.654377 #############
34. "age" 3.634295 ############
35. "wage_per_hour" 3.617817 ############
36. "detailed_household_and_family_stat" 3.594743 ############
37. "capital_losses" 3.439298 ##########
38. "dividends_from_stocks" 3.423652 ##########
39. "capital_gains" 3.222753 ########
40. "education" 3.158698 ########
41. "detailed_industry_recode" 2.981471 ######
42. "detailed_occupation_recode" 2.364817
Variable Importance: NUM_AS_ROOT:
1. "education" 33.000000 ################
2. "capital_gains" 29.000000 ##############
3. "capital_losses" 24.000000 ###########
4. "detailed_household_and_family_stat" 14.000000 ######
5. "dividends_from_stocks" 14.000000 ######
6. "wage_per_hour" 12.000000 #####
7. "country_of_birth_self" 11.000000 #####
8. "detailed_occupation_recode" 11.000000 #####
9. "weeks_worked_in_year" 11.000000 #####
10. "age" 10.000000 ####
11. "state_of_previous_residence" 10.000000 ####
12. "fill_inc_questionnaire_for_veteran's_admin" 9.000000 ####
13. "class_of_worker" 8.000000 ###
14. "full_or_part_time_employment_stat" 8.000000 ###
15. "marital_stat" 8.000000 ###
16. "own_business_or_self_employed" 8.000000 ###
17. "sex" 6.000000 ##
18. "tax_filer_stat" 5.000000 ##
19. "country_of_birth_father" 4.000000 #
20. "race" 3.000000 #
21. "detailed_industry_recode" 2.000000
22. "hispanic_origin" 2.000000
23. "country_of_birth_mother" 1.000000
24. "num_persons_worked_for_employer" 1.000000
25. "reason_for_unemployment" 1.000000
Variable Importance: NUM_NODES:
1. "detailed_occupation_recode" 785.000000 ################
2. "detailed_industry_recode" 668.000000 #############
3. "capital_gains" 275.000000 #####
4. "dividends_from_stocks" 220.000000 ####
5. "capital_losses" 197.000000 ####
6. "education" 178.000000 ###
7. "country_of_birth_mother" 128.000000 ##
8. "country_of_birth_father" 116.000000 ##
9. "age" 114.000000 ##
10. "wage_per_hour" 98.000000 #
11. "state_of_previous_residence" 95.000000 #
12. "detailed_household_and_family_stat" 78.000000 #
13. "class_of_worker" 67.000000 #
14. "country_of_birth_self" 65.000000 #
15. "sex" 65.000000 #
16. "weeks_worked_in_year" 60.000000 #
17. "tax_filer_stat" 57.000000 #
18. "num_persons_worked_for_employer" 54.000000 #
19. "own_business_or_self_employed" 30.000000
20. "marital_stat" 26.000000
21. "member_of_a_labor_union" 16.000000
22. "fill_inc_questionnaire_for_veteran's_admin" 15.000000
23. "full_or_part_time_employment_stat" 15.000000
24. "major_industry_code" 15.000000
25. "hispanic_origin" 9.000000
26. "major_occupation_code" 7.000000
27. "race" 7.000000
28. "citizenship" 1.000000
29. "detailed_household_summary_in_household" 1.000000
30. "migration_code-change_in_reg" 1.000000
31. "migration_prev_res_in_sunbelt" 1.000000
32. "reason_for_unemployment" 1.000000
33. "region_of_previous_residence" 1.000000
34. "veterans_benefits" 1.000000
Variable Importance: SUM_SCORE:
1. "detailed_occupation_recode" 15392441.075369 ################
2. "capital_gains" 5277826.822514 #####
3. "education" 4751749.289550 ####
4. "dividends_from_stocks" 3792002.951255 ###
5. "detailed_industry_recode" 2882200.882109 ##
6. "sex" 2559417.877325 ##
7. "age" 2042990.944829 ##
8. "capital_losses" 1735728.772551 #
9. "weeks_worked_in_year" 1272820.203971 #
10. "tax_filer_stat" 697890.160846
11. "num_persons_worked_for_employer" 671351.905595
12. "detailed_household_and_family_stat" 444620.829557
13. "class_of_worker" 362250.565331
14. "country_of_birth_mother" 296311.574426
15. "country_of_birth_father" 258198.889206
16. "wage_per_hour" 239764.219048
17. "state_of_previous_residence" 237687.602572
18. "country_of_birth_self" 103002.168158
19. "marital_stat" 102449.735314
20. "own_business_or_self_employed" 82938.893541
21. "fill_inc_questionnaire_for_veteran's_admin" 22692.700206
22. "full_or_part_time_employment_stat" 19078.398837
23. "major_industry_code" 18450.345505
24. "member_of_a_labor_union" 14905.360879
25. "hispanic_origin" 12602.867902
26. "major_occupation_code" 8709.665989
27. "race" 6116.282065
28. "citizenship" 3291.490393
29. "detailed_household_summary_in_household" 2733.439375
30. "veterans_benefits" 1230.940488
31. "region_of_previous_residence" 1139.240981
32. "reason_for_unemployment" 219.245124
33. "migration_code-change_in_reg" 55.806436
34. "migration_prev_res_in_sunbelt" 37.780635
Loss: BINOMIAL_LOG_LIKELIHOOD
Validation loss value: 0.228983
Number of trees per iteration: 1
Node format: NOT_SET
Number of trees: 245
Total number of nodes: 7179
Number of nodes by tree:
Count: 245 Average: 29.302 StdDev: 2.96211
Min: 17 Max: 31 Ignored: 0
----------------------------------------------
[ 17, 18) 2 0.82% 0.82%
[ 18, 19) 0 0.00% 0.82%
[ 19, 20) 3 1.22% 2.04%
[ 20, 21) 0 0.00% 2.04%
[ 21, 22) 4 1.63% 3.67%
[ 22, 23) 0 0.00% 3.67%
[ 23, 24) 15 6.12% 9.80% #
[ 24, 25) 0 0.00% 9.80%
[ 25, 26) 5 2.04% 11.84%
[ 26, 27) 0 0.00% 11.84%
[ 27, 28) 21 8.57% 20.41% #
[ 28, 29) 0 0.00% 20.41%
[ 29, 30) 39 15.92% 36.33% ###
[ 30, 31) 0 0.00% 36.33%
[ 31, 31] 156 63.67% 100.00% ##########
Depth by leafs:
Count: 3712 Average: 3.95259 StdDev: 0.249814
Min: 2 Max: 4 Ignored: 0
----------------------------------------------
[ 2, 3) 32 0.86% 0.86%
[ 3, 4) 112 3.02% 3.88%
[ 4, 4] 3568 96.12% 100.00% ##########
Number of training obs by leaf:
Count: 3712 Average: 11849.3 StdDev: 33719.3
Min: 6 Max: 179360 Ignored: 0
----------------------------------------------
[ 6, 8973) 3100 83.51% 83.51% ##########
[ 8973, 17941) 148 3.99% 87.50%
[ 17941, 26909) 79 2.13% 89.63%
[ 26909, 35877) 36 0.97% 90.60%
[ 35877, 44844) 44 1.19% 91.78%
[ 44844, 53812) 17 0.46% 92.24%
[ 53812, 62780) 20 0.54% 92.78%
[ 62780, 71748) 39 1.05% 93.83%
[ 71748, 80715) 24 0.65% 94.48%
[ 80715, 89683) 12 0.32% 94.80%
[ 89683, 98651) 22 0.59% 95.39%
[ 98651, 107619) 21 0.57% 95.96%
[ 107619, 116586) 17 0.46% 96.42%
[ 116586, 125554) 17 0.46% 96.88%
[ 125554, 134522) 13 0.35% 97.23%
[ 134522, 143490) 8 0.22% 97.44%
[ 143490, 152457) 5 0.13% 97.58%
[ 152457, 161425) 6 0.16% 97.74%
[ 161425, 170393) 15 0.40% 98.14%
[ 170393, 179360] 69 1.86% 100.00%
Attribute in nodes:
785 : detailed_occupation_recode [CATEGORICAL]
668 : detailed_industry_recode [CATEGORICAL]
275 : capital_gains [NUMERICAL]
220 : dividends_from_stocks [NUMERICAL]
197 : capital_losses [NUMERICAL]
178 : education [CATEGORICAL]
128 : country_of_birth_mother [CATEGORICAL]
116 : country_of_birth_father [CATEGORICAL]
114 : age [NUMERICAL]
98 : wage_per_hour [NUMERICAL]
95 : state_of_previous_residence [CATEGORICAL]
78 : detailed_household_and_family_stat [CATEGORICAL]
67 : class_of_worker [CATEGORICAL]
65 : sex [CATEGORICAL]
65 : country_of_birth_self [CATEGORICAL]
60 : weeks_worked_in_year [NUMERICAL]
57 : tax_filer_stat [CATEGORICAL]
54 : num_persons_worked_for_employer [NUMERICAL]
30 : own_business_or_self_employed [CATEGORICAL]
26 : marital_stat [CATEGORICAL]
16 : member_of_a_labor_union [CATEGORICAL]
15 : major_industry_code [CATEGORICAL]
15 : full_or_part_time_employment_stat [CATEGORICAL]
15 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
9 : hispanic_origin [CATEGORICAL]
7 : race [CATEGORICAL]
7 : major_occupation_code [CATEGORICAL]
1 : veterans_benefits [CATEGORICAL]
1 : region_of_previous_residence [CATEGORICAL]
1 : reason_for_unemployment [CATEGORICAL]
1 : migration_prev_res_in_sunbelt [CATEGORICAL]
1 : migration_code-change_in_reg [CATEGORICAL]
1 : detailed_household_summary_in_household [CATEGORICAL]
1 : citizenship [CATEGORICAL]
Attribute in nodes with depth <= 0:
33 : education [CATEGORICAL]
29 : capital_gains [NUMERICAL]
24 : capital_losses [NUMERICAL]
14 : dividends_from_stocks [NUMERICAL]
14 : detailed_household_and_family_stat [CATEGORICAL]
12 : wage_per_hour [NUMERICAL]
11 : weeks_worked_in_year [NUMERICAL]
11 : detailed_occupation_recode [CATEGORICAL]
11 : country_of_birth_self [CATEGORICAL]
10 : state_of_previous_residence [CATEGORICAL]
10 : age [NUMERICAL]
9 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
8 : own_business_or_self_employed [CATEGORICAL]
8 : marital_stat [CATEGORICAL]
8 : full_or_part_time_employment_stat [CATEGORICAL]
8 : class_of_worker [CATEGORICAL]
6 : sex [CATEGORICAL]
5 : tax_filer_stat [CATEGORICAL]
4 : country_of_birth_father [CATEGORICAL]
3 : race [CATEGORICAL]
2 : hispanic_origin [CATEGORICAL]
2 : detailed_industry_recode [CATEGORICAL]
1 : reason_for_unemployment [CATEGORICAL]
1 : num_persons_worked_for_employer [NUMERICAL]
1 : country_of_birth_mother [CATEGORICAL]
Attribute in nodes with depth <= 1:
140 : detailed_occupation_recode [CATEGORICAL]
82 : capital_gains [NUMERICAL]
65 : capital_losses [NUMERICAL]
62 : education [CATEGORICAL]
59 : detailed_industry_recode [CATEGORICAL]
47 : dividends_from_stocks [NUMERICAL]
31 : wage_per_hour [NUMERICAL]
26 : detailed_household_and_family_stat [CATEGORICAL]
23 : age [NUMERICAL]
22 : state_of_previous_residence [CATEGORICAL]
21 : country_of_birth_self [CATEGORICAL]
21 : class_of_worker [CATEGORICAL]
20 : weeks_worked_in_year [NUMERICAL]
20 : sex [CATEGORICAL]
15 : country_of_birth_father [CATEGORICAL]
12 : own_business_or_self_employed [CATEGORICAL]
11 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
10 : num_persons_worked_for_employer [NUMERICAL]
9 : tax_filer_stat [CATEGORICAL]
9 : full_or_part_time_employment_stat [CATEGORICAL]
8 : marital_stat [CATEGORICAL]
8 : country_of_birth_mother [CATEGORICAL]
6 : member_of_a_labor_union [CATEGORICAL]
5 : race [CATEGORICAL]
2 : hispanic_origin [CATEGORICAL]
1 : reason_for_unemployment [CATEGORICAL]
Attribute in nodes with depth <= 2:
399 : detailed_occupation_recode [CATEGORICAL]
249 : detailed_industry_recode [CATEGORICAL]
170 : capital_gains [NUMERICAL]
117 : dividends_from_stocks [NUMERICAL]
116 : capital_losses [NUMERICAL]
87 : education [CATEGORICAL]
59 : wage_per_hour [NUMERICAL]
45 : detailed_household_and_family_stat [CATEGORICAL]
43 : country_of_birth_father [CATEGORICAL]
43 : age [NUMERICAL]
40 : country_of_birth_self [CATEGORICAL]
38 : state_of_previous_residence [CATEGORICAL]
38 : class_of_worker [CATEGORICAL]
37 : sex [CATEGORICAL]
36 : weeks_worked_in_year [NUMERICAL]
33 : country_of_birth_mother [CATEGORICAL]
28 : num_persons_worked_for_employer [NUMERICAL]
26 : tax_filer_stat [CATEGORICAL]
14 : own_business_or_self_employed [CATEGORICAL]
14 : marital_stat [CATEGORICAL]
12 : full_or_part_time_employment_stat [CATEGORICAL]
12 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
8 : member_of_a_labor_union [CATEGORICAL]
6 : race [CATEGORICAL]
6 : hispanic_origin [CATEGORICAL]
2 : major_occupation_code [CATEGORICAL]
2 : major_industry_code [CATEGORICAL]
1 : reason_for_unemployment [CATEGORICAL]
1 : migration_prev_res_in_sunbelt [CATEGORICAL]
1 : migration_code-change_in_reg [CATEGORICAL]
Attribute in nodes with depth <= 3:
785 : detailed_occupation_recode [CATEGORICAL]
668 : detailed_industry_recode [CATEGORICAL]
275 : capital_gains [NUMERICAL]
220 : dividends_from_stocks [NUMERICAL]
197 : capital_losses [NUMERICAL]
178 : education [CATEGORICAL]
128 : country_of_birth_mother [CATEGORICAL]
116 : country_of_birth_father [CATEGORICAL]
114 : age [NUMERICAL]
98 : wage_per_hour [NUMERICAL]
95 : state_of_previous_residence [CATEGORICAL]
78 : detailed_household_and_family_stat [CATEGORICAL]
67 : class_of_worker [CATEGORICAL]
65 : sex [CATEGORICAL]
65 : country_of_birth_self [CATEGORICAL]
60 : weeks_worked_in_year [NUMERICAL]
57 : tax_filer_stat [CATEGORICAL]
54 : num_persons_worked_for_employer [NUMERICAL]
30 : own_business_or_self_employed [CATEGORICAL]
26 : marital_stat [CATEGORICAL]
16 : member_of_a_labor_union [CATEGORICAL]
15 : major_industry_code [CATEGORICAL]
15 : full_or_part_time_employment_stat [CATEGORICAL]
15 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
9 : hispanic_origin [CATEGORICAL]
7 : race [CATEGORICAL]
7 : major_occupation_code [CATEGORICAL]
1 : veterans_benefits [CATEGORICAL]
1 : region_of_previous_residence [CATEGORICAL]
1 : reason_for_unemployment [CATEGORICAL]
1 : migration_prev_res_in_sunbelt [CATEGORICAL]
1 : migration_code-change_in_reg [CATEGORICAL]
1 : detailed_household_summary_in_household [CATEGORICAL]
1 : citizenship [CATEGORICAL]
Attribute in nodes with depth <= 5:
785 : detailed_occupation_recode [CATEGORICAL]
668 : detailed_industry_recode [CATEGORICAL]
275 : capital_gains [NUMERICAL]
220 : dividends_from_stocks [NUMERICAL]
197 : capital_losses [NUMERICAL]
178 : education [CATEGORICAL]
128 : country_of_birth_mother [CATEGORICAL]
116 : country_of_birth_father [CATEGORICAL]
114 : age [NUMERICAL]
98 : wage_per_hour [NUMERICAL]
95 : state_of_previous_residence [CATEGORICAL]
78 : detailed_household_and_family_stat [CATEGORICAL]
67 : class_of_worker [CATEGORICAL]
65 : sex [CATEGORICAL]
65 : country_of_birth_self [CATEGORICAL]
60 : weeks_worked_in_year [NUMERICAL]
57 : tax_filer_stat [CATEGORICAL]
54 : num_persons_worked_for_employer [NUMERICAL]
30 : own_business_or_self_employed [CATEGORICAL]
26 : marital_stat [CATEGORICAL]
16 : member_of_a_labor_union [CATEGORICAL]
15 : major_industry_code [CATEGORICAL]
15 : full_or_part_time_employment_stat [CATEGORICAL]
15 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
9 : hispanic_origin [CATEGORICAL]
7 : race [CATEGORICAL]
7 : major_occupation_code [CATEGORICAL]
1 : veterans_benefits [CATEGORICAL]
1 : region_of_previous_residence [CATEGORICAL]
1 : reason_for_unemployment [CATEGORICAL]
1 : migration_prev_res_in_sunbelt [CATEGORICAL]
1 : migration_code-change_in_reg [CATEGORICAL]
1 : detailed_household_summary_in_household [CATEGORICAL]
1 : citizenship [CATEGORICAL]
Condition type in nodes:
2418 : ContainsBitmapCondition
1018 : HigherCondition
31 : ContainsCondition
Condition type in nodes with depth <= 0:
137 : ContainsBitmapCondition
101 : HigherCondition
7 : ContainsCondition
Condition type in nodes with depth <= 1:
448 : ContainsBitmapCondition
278 : HigherCondition
9 : ContainsCondition
Condition type in nodes with depth <= 2:
1097 : ContainsBitmapCondition
569 : HigherCondition
17 : ContainsCondition
Condition type in nodes with depth <= 3:
2418 : ContainsBitmapCondition
1018 : HigherCondition
31 : ContainsCondition
Condition type in nodes with depth <= 5:
2418 : ContainsBitmapCondition
1018 : HigherCondition
31 : ContainsCondition
None
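Target encoding is a common preprocessing technique that converts categorical features into numerical features. Using categorical features with high cardinality as-is may cause overfitting. Target encoding aims to replace each categorical feature value with one or more numerical values that represent its co-occurrence with the target labels.
More precisely, given a categorical feature, the binary target encoder in this example produces three new numerical features:
positive_frequency: how many times each feature value occurred with a positive target label.
negative_frequency: how many times each feature value occurred with a negative target label.
positive_probability: the probability that the target label is positive given the feature value, computed as positive_frequency / (positive_frequency + negative_frequency + correction). The correction term is added to make the division more stable for rare categorical values. The default value of correction is 1.0.
Note that target encoding is effective with models that cannot automatically learn dense representations of categorical features, such as decision forests or kernel methods. If neural network models are used, it is recommended to encode categorical features as embeddings.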
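The three statistics can be computed with a few NumPy calls. The sketch below is a plain NumPy illustration of the formula above, not the Keras layer defined later; binary_target_encode is a hypothetical helper and the toy feature values and labels are made up.

```python
import numpy as np


def binary_target_encode(feature_values, targets, vocabulary_size, correction=1.0):
    """Return a [vocabulary_size, 3] array of
    (positive_frequency, negative_frequency, positive_probability)."""
    feature_values = np.asarray(feature_values, dtype=np.int64)
    targets = np.asarray(targets, dtype=bool)
    # Count how often each feature value co-occurs with a positive / negative label.
    positive_frequency = np.bincount(feature_values[targets], minlength=vocabulary_size)
    negative_frequency = np.bincount(feature_values[~targets], minlength=vocabulary_size)
    # The correction term keeps the ratio stable for rare values.
    positive_probability = positive_frequency / (
        positive_frequency + negative_frequency + correction
    )
    return np.stack([positive_frequency, negative_frequency, positive_probability], axis=1)


features = [0, 2, 0, 1, 1, 2]
labels = [1, 0, 1, 1, 0, 0]
stats = binary_target_encode(features, labels, vocabulary_size=3)
# Value 0 -> [2, 0, 2/3], value 1 -> [1, 1, 1/3], value 2 -> [0, 2, 0].
encoded = stats[np.asarray(features)]  # one row of statistics per example
```

Each categorical column is thereby replaced by three numerical columns. Because the statistics are derived from the labels, they are normally computed on the training split only; computing them on rows later used for evaluation would leak label information into the features.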
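For simplicity, we assume that the inputs to the adapt and call methods have the expected data types and shapes, so no validation logic is added.
It is recommended to pass the vocabulary_size of the categorical feature to the BinaryTargetEncoding constructor. If it is not specified, it is computed during the execution of the adapt() method.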
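Passing vocabulary_size explicitly guards against a subtle pitfall. When integer IDs are used directly as indices into a per-value statistics table, the table needs one slot per possible ID (largest ID + 1), which is more than the number of distinct values whenever some IDs are absent from the data being adapted, for example a rare category that only appears in the test split. The NumPy sketch below uses hypothetical IDs to show the difference.

```python
import numpy as np

# Hypothetical integer-encoded feature values with gaps (IDs 1, 3 and 4 never appear).
feature_values = np.array([0, 5, 5, 2])

num_distinct = len(np.unique(feature_values))  # 3 distinct values
required_size = int(feature_values.max()) + 1  # IDs 0..5 need 6 slots

# One counter per ID; this only works because minlength covers the largest ID.
counts = np.bincount(feature_values, minlength=required_size)
print(num_distinct, required_size, counts.tolist())  # 3 6 [1, 0, 1, 0, 0, 2]
```

A table sized by num_distinct would have no row for ID 5, so any lookup or segment sum using that ID would fail. Supplying the full vocabulary size up front also covers IDs that occur only at inference time.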
class BinaryTargetEncoding(layers.Layer):
    def __init__(self, vocabulary_size=None, correction=1.0, **kwargs):
        super().__init__(**kwargs)
        self.vocabulary_size = vocabulary_size
        self.correction = correction
        # Populated by adapt(); call() raises an error until then.
        self.target_encoding_statistics = None

    def adapt(self, data):
        # data is expected to be an integer NumPy array or Tensor of shape [num_examples, 2].
        # It contains the feature values for a given feature in the dataset, and the target values.

        # Convert the data to a tensor.
        data = tf.convert_to_tensor(data)
        # Separate the feature values and target values.
        feature_values = tf.cast(data[:, 0], tf.dtypes.int32)
        target_values = tf.cast(data[:, 1], tf.dtypes.bool)

        # Compute the vocabulary_size if not specified. The feature values are used as
        # segment IDs below, so the size must cover the largest ID, not just the
        # number of distinct values.
        if self.vocabulary_size is None:
            self.vocabulary_size = int(tf.reduce_max(feature_values)) + 1

        # Filter the data where the target label is positive.
        positive_indices = tf.where(condition=target_values)
        positive_feature_values = tf.gather_nd(
            params=feature_values, indices=positive_indices
        )
        # Compute how many times each feature value occurred with a positive target label.
        positive_frequency = tf.math.unsorted_segment_sum(
            data=tf.ones(
                shape=(positive_feature_values.shape[0], 1), dtype=tf.dtypes.float64
            ),
            segment_ids=positive_feature_values,
            num_segments=self.vocabulary_size,
        )

        # Filter the data where the target label is negative.
        negative_indices = tf.where(condition=tf.math.logical_not(target_values))
        negative_feature_values = tf.gather_nd(
            params=feature_values, indices=negative_indices
        )
        # Compute how many times each feature value occurred with a negative target label.
        negative_frequency = tf.math.unsorted_segment_sum(
            data=tf.ones(
                shape=(negative_feature_values.shape[0], 1), dtype=tf.dtypes.float64
            ),
            segment_ids=negative_feature_values,
            num_segments=self.vocabulary_size,
        )

        # Compute positive probability for the input feature values.
        positive_probability = positive_frequency / (
            positive_frequency + negative_frequency + self.correction
        )
        # Concatenate the computed statistics for target encoding.
        target_encoding_statistics = tf.cast(
            tf.concat(
                [positive_frequency, negative_frequency, positive_probability], axis=1
            ),
            dtype=tf.dtypes.float32,
        )
        self.target_encoding_statistics = tf.constant(target_encoding_statistics)

    def call(self, inputs):
        # inputs is expected to be an integer NumPy array or Tensor of shape [num_examples, 1].
        # It contains the feature values for a given feature in the dataset.

        # Raise an error if the target encoding statistics have not been computed.
        if self.target_encoding_statistics is None:
            raise ValueError(
                "You need to call the adapt method to compute target encoding statistics."
            )

        # Convert the inputs to a tensor.
        inputs = tf.convert_to_tensor(inputs)
        # Cast the inputs to int64.
        inputs = tf.cast(inputs, tf.dtypes.int64)
        # Look up the target encoding statistics for the input feature values.
        target_encoding_statistics = tf.cast(
            tf.gather_nd(self.target_encoding_statistics, inputs),
            dtype=tf.dtypes.float32,
        )
        return target_encoding_statistics
Let's test the binary target encoder:
data = tf.constant(
[
[0, 1],
[2, 0],
[0, 1],
[1, 1],
[1, 1],
[2, 0],
[1, 0],
[0, 1],
[2, 1],
[1, 0],
[0, 1],
[2, 0],
[0, 1],
[1, 1],
[1, 1],
[2, 0],
[1, 0],
[0, 1],
[2, 0],
]
)
binary_target_encoder = BinaryTargetEncoding()
binary_target_encoder.adapt(data)
print(binary_target_encoder([[0], [1], [2]]))
tf.Tensor(
[[6. 0. 0.85714287]
[4. 3. 0.5 ]
[1. 5. 0.14285715]], shape=(3, 3), dtype=float32)
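These values can be checked by hand with the default correction of 1.0. In the toy tensor, value 0 occurs 6 times with label 1 and never with label 0. Value 1 occurs 4 times with label 1 and 3 times with label 0. Value 2 occurs once with label 1 and 5 times with label 0:

$$
\begin{aligned}
p(0) &= \frac{6}{6 + 0 + 1} = \frac{6}{7} \approx 0.857 \\
p(1) &= \frac{4}{4 + 3 + 1} = \frac{4}{8} = 0.5 \\
p(2) &= \frac{1}{1 + 5 + 1} = \frac{1}{7} \approx 0.143
\end{aligned}
$$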
def create_model_inputs():
    inputs = {}
    for feature_name in NUMERIC_FEATURE_NAMES:
        inputs[feature_name] = layers.Input(
            name=feature_name, shape=(), dtype=tf.float32
        )
    for feature_name in CATEGORICAL_FEATURE_NAMES:
        inputs[feature_name] = layers.Input(
            name=feature_name, shape=(), dtype=tf.string
        )
    return inputs
def create_target_encoder():
    inputs = create_model_inputs()
    target_values = train_data[[TARGET_COLUMN_NAME]].to_numpy()
    encoded_features = []
    for feature_name in inputs:
        if feature_name in CATEGORICAL_FEATURE_NAMES:
            # Get the vocabulary of the categorical feature.
            vocabulary = sorted(
                [str(value) for value in list(train_data[feature_name].unique())]
            )
            # Create a lookup to convert string values to integer indices.
            # Since we are not using a mask token nor expecting any out of vocabulary
            # (oov) token, we set mask_token to None and num_oov_indices to 0.
            lookup = layers.StringLookup(
                vocabulary=vocabulary, mask_token=None, num_oov_indices=0
            )
            # Convert the string input values into integer indices.
            value_indices = lookup(inputs[feature_name])
            # Prepare the data to adapt the target encoding.
            print("### Adapting target encoding for:", feature_name)
            feature_values = train_data[[feature_name]].to_numpy().astype(str)
            feature_value_indices = lookup(feature_values)
            data = tf.concat([feature_value_indices, target_values], axis=1)
            # Pass the vocabulary size explicitly, as recommended above.
            feature_encoder = BinaryTargetEncoding(vocabulary_size=len(vocabulary))
            feature_encoder.adapt(data)
            # Convert the feature value indices to target encoding representations.
            encoded_feature = feature_encoder(tf.expand_dims(value_indices, -1))
        else:
            # Expand the dimensions of the numerical input feature and use it as-is.
            encoded_feature = tf.expand_dims(inputs[feature_name], -1)
        # Add the encoded feature to the list.
        encoded_features.append(encoded_feature)
    # Concatenate all the encoded features.
    encoded_features = tf.concat(encoded_features, axis=1)
    # Create and return a Keras model with encoded features as outputs.
    return keras.Model(inputs=inputs, outputs=encoded_features)
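In this scenario, we use target encoding as a preprocessor for the Gradient Boosted Trees model and let the model infer the semantics of the input features.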
def create_gbt_with_preprocessor(preprocessor):
    gbt_model = tfdf.keras.GradientBoostedTreesModel(
        preprocessing=preprocessor,
        num_trees=NUM_TREES,
        max_depth=MAX_DEPTH,
        min_examples=MIN_EXAMPLES,
        subsample=SUBSAMPLE,
        validation_ratio=VALIDATION_RATIO,
        task=tfdf.keras.Task.CLASSIFICATION,
    )
    gbt_model.compile(metrics=[keras.metrics.BinaryAccuracy(name="accuracy")])
    return gbt_model
gbt_model = create_gbt_with_preprocessor(create_target_encoder())
run_experiment(gbt_model, train_data, test_data)
### Adapting target encoding for: class_of_worker
### Adapting target encoding for: detailed_industry_recode
### Adapting target encoding for: detailed_occupation_recode
### Adapting target encoding for: education
### Adapting target encoding for: enroll_in_edu_inst_last_wk
### Adapting target encoding for: marital_stat
### Adapting target encoding for: major_industry_code
### Adapting target encoding for: major_occupation_code
### Adapting target encoding for: race
### Adapting target encoding for: hispanic_origin
### Adapting target encoding for: sex
### Adapting target encoding for: member_of_a_labor_union
### Adapting target encoding for: reason_for_unemployment
### Adapting target encoding for: full_or_part_time_employment_stat
### Adapting target encoding for: tax_filer_stat
### Adapting target encoding for: region_of_previous_residence
### Adapting target encoding for: state_of_previous_residence
### Adapting target encoding for: detailed_household_and_family_stat
### Adapting target encoding for: detailed_household_summary_in_household
### Adapting target encoding for: migration_code-change_in_msa
### Adapting target encoding for: migration_code-change_in_reg
### Adapting target encoding for: migration_code-move_within_reg
### Adapting target encoding for: live_in_this_house_1_year_ago
### Adapting target encoding for: migration_prev_res_in_sunbelt
### Adapting target encoding for: family_members_under_18
### Adapting target encoding for: country_of_birth_father
### Adapting target encoding for: country_of_birth_mother
### Adapting target encoding for: country_of_birth_self
### Adapting target encoding for: citizenship
### Adapting target encoding for: own_business_or_self_employed
### Adapting target encoding for: fill_inc_questionnaire_for_veteran's_admin
### Adapting target encoding for: veterans_benefits
### Adapting target encoding for: year
Use /tmp/tmpj_0h78ld as temporary training directory
Starting reading the dataset
198/200 [============================>.] - ETA: 0s
Dataset read in 0:00:06.793717
Training model
Model trained in 0:04:32.752691
Compiling model
200/200 [==============================] - 280s 1s/step
Test accuracy: 95.81%
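In this scenario, we build an encoder model that encodes the categorical features as embeddings. The embedding size for a given categorical feature is the square root of its vocabulary size, truncated to an integer.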
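As a quick illustration of the sizing rule, the sketch below applies the same int(math.sqrt(...)) expression used in create_embedding_encoder to a few hypothetical vocabulary sizes, not the census vocabularies. It then computes how wide the concatenated feature vector is before the dropout and Dense projection, assuming 7 numeric features as in NUMERIC_FEATURE_NAMES. Both helper names are hypothetical.

```python
import math


def embedding_size(vocabulary_size):
    # Same rule as create_embedding_encoder: int() truncates the square root.
    return int(math.sqrt(vocabulary_size))


def concatenated_width(num_numeric_features, vocabulary_sizes):
    # Each numeric feature contributes one column (tf.expand_dims(..., -1));
    # each categorical feature contributes embedding_size(v) columns.
    return num_numeric_features + sum(embedding_size(v) for v in vocabulary_sizes)


example_vocabulary_sizes = [2, 9, 50]  # hypothetical, not taken from the dataset
print([embedding_size(v) for v in example_vocabulary_sizes])  # [1, 3, 7]
print(concatenated_width(7, example_vocabulary_sizes))  # 18
```

Because of the truncation, vocabularies of 9 through 15 values all get 3-dimensional embeddings.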
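We train these embeddings through backpropagation in a simple neural network model. Once the embedding encoder is trained, we use it as a preprocessor for the input features of a Gradient Boosted Trees model.
Note that the embeddings and the decision forest model cannot be trained jointly in a single phase, because decision forest models are not trained with backpropagation. Instead, the embeddings must be trained in an initial phase and then used as static inputs to the decision forest model.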
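The two-phase recipe can be illustrated without TensorFlow. Below is a minimal NumPy sketch built on deliberate simplifications. Phase 1 trains an embedding table by backpropagation through a logistic-regression head, which stands in for the small neural network. Phase 2 freezes the table and passes the looked-up vectors as static features to a single decision stump found by exhaustive search. The stump is a stand-in for the Gradient Boosted Trees model, which is much richer. The names train_embeddings and fit_stump are hypothetical. The toy ids and labels are the rows of the tensor used to test the binary target encoder.

```python
import numpy as np


def train_embeddings(ids, labels, vocabulary_size, dim, epochs=200, lr=0.5, seed=0):
    """Phase 1: learn an embedding table by gradient descent through a
    sigmoid + binary cross-entropy head (backpropagation through a lookup)."""
    rng = np.random.default_rng(seed)
    table = rng.normal(scale=0.1, size=(vocabulary_size, dim))
    w = rng.normal(scale=0.1, size=dim)
    b = 0.0
    for _ in range(epochs):
        emb = table[ids]  # embedding lookup, shape [n, dim]
        p = 1.0 / (1.0 + np.exp(-(emb @ w + b)))  # sigmoid head
        grad_logits = (p - labels) / len(labels)  # d(mean BCE) / d(logits)
        grad_w = emb.T @ grad_logits
        grad_b = grad_logits.sum()
        grad_table = np.zeros_like(table)
        # Scatter-add each example's gradient back into its embedding row.
        np.add.at(grad_table, ids, np.outer(grad_logits, w))
        table = table - lr * grad_table
        w = w - lr * grad_w
        b = b - lr * grad_b
    return table, w, b


def fit_stump(features, labels):
    """Phase 2 stand-in for a tree learner: exhaustive search for the single
    (feature, threshold, polarity) split with the best training accuracy.
    No gradients flow back into the embeddings here."""
    targets = labels.astype(bool)
    best_acc, best_split = -1.0, None
    for j in range(features.shape[1]):
        for threshold in np.unique(features[:, j]):
            above = features[:, j] > threshold
            for predict_above in (True, False):
                predictions = above if predict_above else ~above
                acc = np.mean(predictions == targets)
                if acc > best_acc:
                    best_acc, best_split = acc, (j, threshold, predict_above)
    return best_acc, best_split


ids = np.array([0, 2, 0, 1, 1, 2, 1, 0, 2, 1, 0, 2, 0, 1, 1, 2, 1, 0, 2])
labels = np.array([1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0])
table, w, b = train_embeddings(ids, labels, vocabulary_size=3, dim=2)
static_features = table[ids]  # frozen: phase 2 never updates the table
accuracy, split = fit_stump(static_features, labels)
print(f"stump accuracy on static embeddings: {accuracy:.3f}, split: {split}")
```

Nothing in fit_stump computes a gradient, which is why the embeddings have to be finished before phase 2 starts.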
def create_embedding_encoder(size=None):
    inputs = create_model_inputs()
    encoded_features = []
    for feature_name in inputs:
        if feature_name in CATEGORICAL_FEATURE_NAMES:
            # Get the vocabulary of the categorical feature.
            vocabulary = sorted(
                [str(value) for value in list(train_data[feature_name].unique())]
            )
            # Create a lookup to convert string values to integer indices.
            # Since we are not using a mask token nor expecting any out of vocabulary
            # (oov) token, we set mask_token to None and num_oov_indices to 0.
            lookup = layers.StringLookup(
                vocabulary=vocabulary, mask_token=None, num_oov_indices=0
            )
            # Convert the string input values into integer indices.
            value_index = lookup(inputs[feature_name])
            # Create an embedding layer whose size is the square root of the vocabulary size.
            vocabulary_size = len(vocabulary)
            embedding_size = int(math.sqrt(vocabulary_size))
            feature_encoder = layers.Embedding(
                input_dim=vocabulary_size, output_dim=embedding_size
            )
            # Convert the index values to embedding representations.
            encoded_feature = feature_encoder(value_index)
        else:
            # Expand the dimensions of the numerical input feature and use it as-is.
            encoded_feature = tf.expand_dims(inputs[feature_name], -1)
        # Add the encoded feature to the list.
        encoded_features.append(encoded_feature)
    # Concatenate all the encoded features.
    encoded_features = layers.concatenate(encoded_features, axis=1)
    # Apply dropout.
    encoded_features = layers.Dropout(rate=0.25)(encoded_features)
    # Apply a non-linear projection.
    encoded_features = layers.Dense(
        units=size if size else encoded_features.shape[-1], activation="gelu"
    )(encoded_features)
    # Create and return a Keras model with encoded features as outputs.
    return keras.Model(inputs=inputs, outputs=encoded_features)
def create_nn_model(encoder):
    inputs = create_model_inputs()
    embeddings = encoder(inputs)
    output = layers.Dense(units=1, activation="sigmoid")(embeddings)
    nn_model = keras.Model(inputs=inputs, outputs=output)
    nn_model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.BinaryCrossentropy(),
        metrics=[keras.metrics.BinaryAccuracy("accuracy")],
    )
    return nn_model
embedding_encoder = create_embedding_encoder(size=64)
run_experiment(
create_nn_model(embedding_encoder),
train_data,
test_data,
num_epochs=5,
batch_size=256,
)
Epoch 1/5
200/200 [==============================] - 10s 27ms/step - loss: 8303.1455 - accuracy: 0.9193
Epoch 2/5
200/200 [==============================] - 5s 27ms/step - loss: 1019.4900 - accuracy: 0.9371
Epoch 3/5
200/200 [==============================] - 5s 27ms/step - loss: 612.2844 - accuracy: 0.9416
Epoch 4/5
200/200 [==============================] - 5s 27ms/step - loss: 858.9774 - accuracy: 0.9397
Epoch 5/5
200/200 [==============================] - 5s 26ms/step - loss: 842.3922 - accuracy: 0.9421
Test accuracy: 95.0%
gbt_model = create_gbt_with_preprocessor(embedding_encoder)
run_experiment(gbt_model, train_data, test_data)
Use /tmp/tmpao5o88p6 as temporary training directory
Starting reading the dataset
199/200 [============================>.] - ETA: 0s
Dataset read in 0:00:06.722677
Training model
Model trained in 0:05:18.350298
Compiling model
200/200 [==============================] - 325s 2s/step
Test accuracy: 95.82%
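TensorFlow Decision Forests provides powerful models, especially for structured data. In our experiments, the Gradient Boosted Trees model achieved 95.79% test accuracy. With target encoding applied to the categorical features, the same model achieved 95.81% test accuracy. With pretrained embeddings used as its inputs, the Gradient Boosted Trees model achieved 95.82% test accuracy.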
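Decision forests can be combined with neural networks in two ways: 1) use a neural network to learn a useful representation of the input data, then use a decision forest for the supervised learning task; or 2) build an ensemble of decision forest and neural network models.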
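For option 2), one simple way to ensemble is to average the positive-class probabilities the two models produce for the same examples. This choice is ours for illustration; the example itself does not build an ensemble. The minimal sketch below assumes each model's scores are already available as arrays, for instance from each model's predict() output. The function name ensemble_probabilities and the scores are hypothetical.

```python
import numpy as np


def ensemble_probabilities(forest_probs, nn_probs, forest_weight=0.5):
    """Hypothetical helper: convex combination of two models'
    positive-class probabilities for the same examples."""
    if not 0.0 <= forest_weight <= 1.0:
        raise ValueError("forest_weight must be in [0, 1]")
    forest_probs = np.asarray(forest_probs, dtype=np.float64)
    nn_probs = np.asarray(nn_probs, dtype=np.float64)
    if forest_probs.shape != nn_probs.shape:
        raise ValueError("both models must score the same examples")
    return forest_weight * forest_probs + (1.0 - forest_weight) * nn_probs


# Hypothetical scores for two test examples from each model.
forest_probs = [0.9, 0.2]
nn_probs = [0.7, 0.4]
blended = ensemble_probabilities(forest_probs, nn_probs)  # approximately [0.8, 0.3]
predicted_labels = blended > 0.5
print(blended, predicted_labels)
```

The forest_weight parameter lets one model dominate, for example if it scores better on a validation split.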
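Note that TensorFlow Decision Forests does not (yet) support hardware accelerators; all training and inference is done on the CPU. In addition, decision forests require a finite dataset that fits in memory for training. However, increasing the dataset size yields diminishing returns, and decision forest algorithms arguably need fewer examples to converge than large neural network models.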