作者: fchollet
创建日期 2020/06/09
最后修改日期 2020/06/09
描述: 结构化数据的二元分类,包括数值和类别特征。
本示例演示了如何从原始 CSV 文件开始进行结构化数据分类。我们的数据包括数值和类别特征。我们将使用 Keras 预处理层来标准化数值特征和向量化类别特征。
请注意,本示例应在 TensorFlow 2.5 或更高版本上运行。
我们的数据集由克利夫兰诊所基金会提供,用于心脏病研究。它是一个 CSV 文件,包含 303 行。每行包含关于患者的信息(一个 样本),每列描述患者的一个属性(一个 特征)。我们使用这些特征来预测患者是否患有心脏病(二元分类)。
以下是每个特征的描述
列 | 描述 | 特征类型 |
---|---|---|
年龄 | 年龄(年) | 数值型 |
性别 | (1 = 男性;0 = 女性) | 类别型 |
CP | 胸痛类型(0、1、2、3、4) | 类别型 |
Trestbpd | 静息血压(入院时以毫米汞柱为单位) | 数值型 |
Chol | 血清胆固醇(毫克/分升) | 数值型 |
FBS | 空腹血糖 > 120 毫克/分升(1 = 真;0 = 假) | 类别型 |
RestECG | 静息心电图结果(0、1、2) | 类别型 |
Thalach | 达到的最大心率 | 数值型 |
Exang | 运动诱发心绞痛(1 = 是;0 = 否) | 类别型 |
Oldpeak | 运动相对于静息引起的 ST 段压低 | 数值型 |
Slope | 峰值运动 ST 段的斜率 | 数值型 |
CA | 荧光造影着色的主要血管数 (0-3) | 数值型和类别型 |
Thal | 3 = 正常;6 = 固定缺陷;7 = 可逆缺陷 | 类别型 |
Target | 心脏病诊断(1 = 真;0 = 假) | Target |
import os
os.environ["KERAS_BACKEND"] = "torch" # or torch, or tensorflow
import pandas as pd
import keras
from keras import layers
让我们下载数据并将其加载到 Pandas 数据帧中
file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
dataframe = pd.read_csv(file_url)
数据集包含 303 个样本,每个样本 14 列(13 个特征,加上目标标签)
dataframe.shape
(303, 14)
以下是一些样本的预览
dataframe.head()
age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | 1 | 145 | 233 | 1 | 2 | 150 | 0 | 2.3 | 3 | 0 | fixed | 0 |
1 | 67 | 1 | 4 | 160 | 286 | 0 | 2 | 108 | 1 | 1.5 | 2 | 3 | normal | 1 |
2 | 67 | 1 | 4 | 120 | 229 | 0 | 2 | 129 | 1 | 2.6 | 2 | 2 | reversible | 0 |
3 | 37 | 1 | 3 | 130 | 250 | 0 | 0 | 187 | 0 | 3.5 | 3 | 0 | normal | 0 |
4 | 41 | 0 | 2 | 130 | 204 | 0 | 2 | 172 | 0 | 1.4 | 1 | 0 | normal | 0 |
最后一列“target”表示患者是否患有心脏病 (1) 或未患病 (0)。
让我们将数据拆分为训练集和验证集
val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
train_dataframe = dataframe.drop(val_dataframe.index)
print(
f"Using {len(train_dataframe)} samples for training "
f"and {len(val_dataframe)} for validation"
)
Using 242 samples for training and 61 for validation
在这里,我们定义数据集的元数据,这将有助于读取数据并将其解析为输入特征,并根据特征类型对输入特征进行编码。
COLUMN_NAMES = [
"age",
"sex",
"cp",
"trestbps",
"chol",
"fbs",
"restecg",
"thalach",
"exang",
"oldpeak",
"slope",
"ca",
"thal",
"target",
]
# Target feature name.
TARGET_FEATURE_NAME = "target"
# Numeric feature names.
NUMERIC_FEATURE_NAMES = ["age", "trestbps", "thalach", "oldpeak", "slope", "chol"]
# Categorical features and their vocabulary lists.
# Note that we add 'v=' as a prefix to all categorical feature values to make
# sure that they are treated as strings.
CATEGORICAL_FEATURES_WITH_VOCABULARY = {
feature_name: sorted(
[
# Integer categorcal must be int and string must be str
value if dataframe[feature_name].dtype == "int64" else str(value)
for value in list(dataframe[feature_name].unique())
]
)
for feature_name in COLUMN_NAMES
if feature_name not in list(NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME])
}
# All features names.
FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
)
以下特征是编码为整数的类别特征
sex
cp
fbs
restecg
exang
ca
我们将使用 独热编码 对这些特征进行编码。我们这里有两个选择
CategoryEncoding()
,这需要知道输入值的范围,并且在输入超出范围时会报错。IntegerLookup()
,它将为输入构建查找表,并为未知输入值保留输出索引。对于本示例,我们想要一个简单的解决方案,可以处理推理时超出范围的输入,因此我们将使用 IntegerLookup()
。
我们还有一个编码为字符串的类别特征:thal
。我们将创建所有可能特征的索引,并使用 StringLookup()
层对输出进行编码。
最后,以下特征是连续数值特征
age
trestbps
chol
thalach
oldpeak
slope
对于这些特征中的每一个,我们将使用 Normalization()
层来确保每个特征的均值为 0,标准差为 1。
下面,我们定义 2 个实用函数来执行操作
encode_numerical_feature
对数值特征应用特征级标准化。process
对字符串或整数类别特征进行独热编码。# Tensorflow required for tf.data.Dataset
import tensorflow as tf
# We process our datasets elements here (categorical) and convert them to indices to avoid this step
# during model training since only tensorflow support strings.
def encode_categorical(features, target):
for feature_name in features:
if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
lookup_class = (
layers.StringLookup
if features[feature_name].dtype == "string"
else layers.IntegerLookup
)
vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
# Create a lookup to convert a string values to an integer indices.
# Since we are not using a mask token nor expecting any out of vocabulary
# (oov) token, we set mask_token to None and num_oov_indices to 0.
index = lookup_class(
vocabulary=vocabulary,
mask_token=None,
num_oov_indices=0,
output_mode="binary",
)
# Convert the string input values into integer indices.
value_index = index(features[feature_name])
features[feature_name] = value_index
else:
pass
# Change features from OrderedDict to Dict to match Inputs as they are Dict.
return dict(features), target
def encode_numerical_feature(feature, name, dataset):
# Create a Normalization layer for our feature
normalizer = layers.Normalization()
# Prepare a Dataset that only yields our feature
feature_ds = dataset.map(lambda x, y: x[name])
feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))
# Learn the statistics of the data
normalizer.adapt(feature_ds)
# Normalize the input feature
encoded_feature = normalizer(feature)
return encoded_feature
让我们为每个数据帧生成 tf.data.Dataset
对象
def dataframe_to_dataset(dataframe):
dataframe = dataframe.copy()
labels = dataframe.pop("target")
ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels)).map(
encode_categorical
)
ds = ds.shuffle(buffer_size=len(dataframe))
return ds
train_ds = dataframe_to_dataset(train_dataframe)
val_ds = dataframe_to_dataset(val_dataframe)
每个 Dataset
产生一个元组 (input, target)
,其中 input
是特征字典,target
是值 0
或 1
for x, y in train_ds.take(1):
print("Input:", x)
print("Target:", y)
Input: {'age': <tf.Tensor: shape=(), dtype=int64, numpy=45>, 'sex': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>, 'cp': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 0, 1])>, 'trestbps': <tf.Tensor: shape=(), dtype=int64, numpy=142>, 'chol': <tf.Tensor: shape=(), dtype=int64, numpy=309>, 'fbs': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'restecg': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 0, 1])>, 'thalach': <tf.Tensor: shape=(), dtype=int64, numpy=147>, 'exang': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>, 'oldpeak': <tf.Tensor: shape=(), dtype=float64, numpy=0.0>, 'slope': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'ca': <tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 0, 0, 1])>, 'thal': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 0, 1])>}
Target: tf.Tensor(1, shape=(), dtype=int64)
让我们批量处理数据集
train_ds = train_ds.batch(32)
val_ds = val_ds.batch(32)
完成此操作后,我们可以创建端到端模型
# Categorical features have different shapes after the encoding, dependent on the
# vocabulary or unique values of each feature. We create them accordinly to match the
# input data elements generated by tf.data.Dataset after pre-processing them
def create_model_inputs():
inputs = {}
# This a helper function for creating categorical features
def create_input_helper(feature_name):
num_categories = len(CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name])
inputs[feature_name] = layers.Input(
name=feature_name, shape=(num_categories,), dtype="int64"
)
return inputs
for feature_name in FEATURE_NAMES:
if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
# Categorical features
create_input_helper(feature_name)
else:
# Make them float32, they are Real numbers
feature_input = layers.Input(name=feature_name, shape=(1,), dtype="float32")
# Process the Inputs here
inputs[feature_name] = encode_numerical_feature(
feature_input, feature_name, train_ds
)
return inputs
# This Layer defines the logic of the Model to perform the classification
class Classifier(keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.dense_1 = layers.Dense(32, activation="relu")
self.dropout = layers.Dropout(0.5)
self.dense_2 = layers.Dense(1, activation="sigmoid")
def call(self, inputs):
all_features = layers.concatenate(list(inputs.values()))
x = self.dense_1(all_features)
x = self.dropout(x)
output = self.dense_2(x)
return output
# Surpress build warnings
def build(self, input_shape):
self.built = True
# Create the Classifier model
def create_model():
all_inputs = create_model_inputs()
output = Classifier()(all_inputs)
model = keras.Model(all_inputs, output)
return model
model = create_model()
model.compile("adam", "binary_crossentropy", metrics=["accuracy"])
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:106: UserWarning: When providing `inputs` as a dict, all keys in the dict must match the names of the corresponding tensors. Received key 'age' mapping to value <KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor> which has name 'keras_tensor'. Change the tensor name to 'age' (via `Input(..., name='age')`)
warnings.warn(
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:106: UserWarning: When providing `inputs` as a dict, all keys in the dict must match the names of the corresponding tensors. Received key 'trestbps' mapping to value <KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_1> which has name 'keras_tensor_1'. Change the tensor name to 'trestbps' (via `Input(..., name='trestbps')`)
warnings.warn(
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:106: UserWarning: When providing `inputs` as a dict, all keys in the dict must match the names of the corresponding tensors. Received key 'thalach' mapping to value <KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_2> which has name 'keras_tensor_2'. Change the tensor name to 'thalach' (via `Input(..., name='thalach')`)
warnings.warn(
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:106: UserWarning: When providing `inputs` as a dict, all keys in the dict must match the names of the corresponding tensors. Received key 'oldpeak' mapping to value <KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_3> which has name 'keras_tensor_3'. Change the tensor name to 'oldpeak' (via `Input(..., name='oldpeak')`)
warnings.warn(
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:106: UserWarning: When providing `inputs` as a dict, all keys in the dict must match the names of the corresponding tensors. Received key 'slope' mapping to value <KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_4> which has name 'keras_tensor_4'. Change the tensor name to 'slope' (via `Input(..., name='slope')`)
warnings.warn(
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:106: UserWarning: When providing `inputs` as a dict, all keys in the dict must match the names of the corresponding tensors. Received key 'chol' mapping to value <KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_5> which has name 'keras_tensor_5'. Change the tensor name to 'chol' (via `Input(..., name='chol')`)
warnings.warn(
让我们可视化我们的连接图
# `rankdir='LR'` is to make the graph horizontal.
keras.utils.plot_model(model, show_shapes=True, rankdir="LR")
model.fit(train_ds, epochs=50, validation_data=val_ds)
Epoch 1/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 102毫秒/步 - 准确率: 0.4688 - 损失: 8.0563
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 8毫秒/步 - 准确率: 0.4732 - 损失: 7.9796
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.4725 - 损失: 7.9848 - 验证准确率: 0.2295 - 验证损失: 12.0816
Epoch 2/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 105毫秒/步 - 准确率: 0.5000 - 损失: 6.6368
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 8毫秒/步 - 准确率: 0.4532 - 损失: 7.8320
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 18毫秒/步 - 准确率: 0.4547 - 损失: 7.8310 - 验证准确率: 0.2459 - 验证损失: 6.2543
Epoch 3/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 91毫秒/步 - 准确率: 0.5000 - 损失: 7.6558
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 9毫秒/步 - 准确率: 0.5041 - 损失: 7.3378
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 19毫秒/步 - 准确率: 0.5087 - 损失: 7.2802 - 验证准确率: 0.6885 - 验证损失: 2.1633
Epoch 4/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 93毫秒/步 - 准确率: 0.4375 - 损失: 8.9030
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 8毫秒/步 - 准确率: 0.4815 - 损失: 8.0109
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 19毫秒/步 - 准确率: 0.4858 - 损失: 7.9351 - 验证准确率: 0.7705 - 验证损失: 3.3916
Epoch 5/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 88毫秒/步 - 准确率: 0.4688 - 损失: 8.1279
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 10毫秒/步 - 准确率: 0.5049 - 损失: 7.4815
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.5117 - 损失: 7.4054 - 验证准确率: 0.7705 - 验证损失: 3.6911
Epoch 6/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 107毫秒/步 - 准确率: 0.4688 - 损失: 7.8832
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 10毫秒/步 - 准确率: 0.4940 - 损失: 7.4615
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.5121 - 损失: 7.1851 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 7/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 101毫秒/步 - 准确率: 0.5312 - 损失: 6.9446
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 12毫秒/步 - 准确率: 0.5357 - 损失: 6.5511
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.5497 - 损失: 6.3711 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 8/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 110毫秒/步 - 准确率: 0.5938 - 损失: 6.3905
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 10毫秒/步 - 准确率: 0.6192 - 损失: 5.9601
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.6101 - 损失: 6.0728 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 9/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 108毫秒/步 - 准确率: 0.5938 - 损失: 6.5442
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 10毫秒/步 - 准确率: 0.6006 - 损失: 6.3309
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 21毫秒/步 - 准确率: 0.5949 - 损失: 6.3647 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 10/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 113毫秒/步 - 准确率: 0.5625 - 损失: 6.8250
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 10毫秒/步 - 准确率: 0.5675 - 损失: 6.5020
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.5764 - 损失: 6.3308 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 11/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 116毫秒/步 - 准确率: 0.6250 - 损失: 4.3582
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 9毫秒/步 - 准确率: 0.6053 - 损失: 5.4824
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.6076 - 损失: 5.4500 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 12/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 118毫秒/步 - 准确率: 0.5625 - 损失: 7.0064
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 9毫秒/步 - 准确率: 0.5740 - 损失: 6.4431
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 21毫秒/步 - 准确率: 0.5787 - 损失: 6.3510 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 13/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 115毫秒/步 - 准确率: 0.7500 - 损失: 3.7382
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 10毫秒/步 - 准确率: 0.6812 - 损失: 4.7893
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 21毫秒/步 - 准确率: 0.6712 - 损失: 4.9453 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 14/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 114毫秒/步 - 准确率: 0.6562 - 损失: 5.5498
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 9毫秒/步 - 准确率: 0.6580 - 损失: 5.4636
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 20毫秒/步 - 准确率: 0.6578 - 损失: 5.4379 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 15/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 113毫秒/步 - 准确率: 0.5938 - 损失: 5.8118
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 8毫秒/步 - 准确率: 0.5978 - 损失: 5.9295
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 21毫秒/步 - 准确率: 0.6045 - 损失: 5.8426 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 16/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 115毫秒/步 - 准确率: 0.6562 - 损失: 4.4893
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 9毫秒/步 - 准确率: 0.5763 - 损失: 5.9135
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.5814 - 损失: 5.8590 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 17/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 127毫秒/步 - 准确率: 0.5625 - 损失: 7.0281
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 10毫秒/步 - 准确率: 0.6071 - 损失: 6.0424
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 24毫秒/步 - 准确率: 0.6179 - 损失: 5.8262 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 18/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 130毫秒/步 - 准确率: 0.6562 - 损失: 5.3547
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 11毫秒/步 - 准确率: 0.6701 - 损失: 5.0648
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 25毫秒/步 - 准确率: 0.6713 - 损失: 5.0607 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 19/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 121毫秒/步 - 准确率: 0.7500 - 损失: 4.0295
5/8 ━━━━━━━━━━━━ [37m━━━━━━━━ 0秒 13毫秒/步 - 准确率: 0.7157 - 损失: 4.3995
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 25毫秒/步 - 准确率: 0.7077 - 损失: 4.4886 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 20/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 129毫秒/步 - 准确率: 0.6250 - 损失: 6.0278
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 11毫秒/步 - 准确率: 0.6479 - 损失: 5.4982
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 24毫秒/步 - 准确率: 0.6461 - 损失: 5.4898 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 21/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 134毫秒/步 - 准确率: 0.5938 - 损失: 5.8592
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 10毫秒/步 - 准确率: 0.6782 - 损失: 4.7529
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 23毫秒/步 - 准确率: 0.6627 - 损失: 5.0219 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 22/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 127毫秒/步 - 准确率: 0.6875 - 损失: 5.0149
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 10毫秒/步 - 准确率: 0.6342 - 损失: 5.5898
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 25毫秒/步 - 准确率: 0.6290 - 损失: 5.6701 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 23/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 121毫秒/步 - 准确率: 0.5938 - 损失: 6.0783
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 11毫秒/步 - 准确率: 0.6259 - 损失: 5.6908
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 24毫秒/步 - 准确率: 0.6352 - 损失: 5.5719 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 24/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 112毫秒/步 - 准确率: 0.7812 - 损失: 3.1021
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 12毫秒/步 - 准确率: 0.7353 - 损失: 3.8725
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 23毫秒/步 - 准确率: 0.7163 - 损失: 4.1637 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 25/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 112毫秒/步 - 准确率: 0.5625 - 损失: 6.9224
5/8 ━━━━━━━━━━━━ [37m━━━━━━━━ 0秒 13毫秒/步 - 准确率: 0.6331 - 损失: 5.5663
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 23毫秒/步 - 准确率: 0.6416 - 损失: 5.4024 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 26/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 117毫秒/步 - 准确率: 0.6875 - 损失: 4.4043
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 10毫秒/步 - 准确率: 0.6668 - 损失: 5.0742
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.6743 - 损失: 4.9986 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 27/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 104毫秒/步 - 准确率: 0.6562 - 损失: 5.3405
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 8毫秒/步 - 准确率: 0.6868 - 损失: 4.7990
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 20毫秒/步 - 准确率: 0.6838 - 损失: 4.8458 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 28/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 116毫秒/步 - 准确率: 0.6562 - 损失: 4.8092
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 9毫秒/步 - 准确率: 0.7061 - 损失: 4.3996
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 20毫秒/步 - 准确率: 0.7053 - 损失: 4.4297 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 29/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 114毫秒/步 - 准确率: 0.6250 - 损失: 5.6655
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 10毫秒/步 - 准确率: 0.6536 - 损失: 5.3912
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 21毫秒/步 - 准确率: 0.6589 - 损失: 5.3014 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 30/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 116毫秒/步 - 准确率: 0.7812 - 损失: 3.5258
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 9毫秒/步 - 准确率: 0.6900 - 损失: 4.7711
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 20毫秒/步 - 准确率: 0.6882 - 损失: 4.8074 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 31/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 123毫秒/步 - 准确率: 0.5938 - 损失: 6.5425
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 10毫秒/步 - 准确率: 0.6346 - 损失: 5.6779
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.6423 - 损失: 5.5672 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 32/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 120毫秒/步 - 准确率: 0.6250 - 损失: 5.6215
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 11毫秒/步 - 准确率: 0.6451 - 损失: 5.2140
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 23毫秒/步 - 准确率: 0.6556 - 损失: 5.0993 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 33/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 115毫秒/步 - 准确率: 0.7188 - 损失: 4.2096
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 10毫秒/步 - 准确率: 0.7218 - 损失: 4.3075
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 23毫秒/步 - 准确率: 0.7143 - 损失: 4.4143 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 34/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 114毫秒/步 - 准确率: 0.5625 - 损失: 7.0242
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 11毫秒/步 - 准确率: 0.6608 - 损失: 5.3428
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 24毫秒/步 - 准确率: 0.6675 - 损失: 5.2031 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 35/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 105毫秒/步 - 准确率: 0.6875 - 损失: 5.0369
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 11毫秒/步 - 准确率: 0.6601 - 损失: 5.2386
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 24毫秒/步 - 准确率: 0.6675 - 损失: 5.0972 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 36/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 114毫秒/步 - 准确率: 0.6562 - 损失: 4.8957
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 11毫秒/步 - 准确率: 0.7086 - 损失: 4.4144
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 23毫秒/步 - 准确率: 0.6980 - 损失: 4.5912 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 37/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 115毫秒/步 - 准确率: 0.6250 - 损失: 6.0333
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 11毫秒/步 - 准确率: 0.6438 - 损失: 5.6852
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 23毫秒/步 - 准确率: 0.6551 - 损失: 5.4504 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 38/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 119毫秒/步 - 准确率: 0.5938 - 损失: 6.4043
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 10毫秒/步 - 准确率: 0.6659 - 损失: 5.2220
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.6751 - 损失: 5.0637 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 39/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 122毫秒/步 - 准确率: 0.5625 - 损失: 7.0517
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 10毫秒/步 - 准确率: 0.6782 - 损失: 5.0396
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.6854 - 损失: 4.9129 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 40/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 121毫秒/步 - 准确率: 0.6562 - 损失: 5.4278
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 10毫秒/步 - 准确率: 0.6575 - 损失: 5.2183
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.6676 - 损失: 5.0430 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 41/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 120毫秒/步 - 准确率: 0.7500 - 损失: 3.9611
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 10毫秒/步 - 准确率: 0.7322 - 损失: 4.2233
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 24毫秒/步 - 准确率: 0.7325 - 损失: 4.2274 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 42/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 127毫秒/步 - 准确率: 0.8438 - 损失: 2.5075
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 11毫秒/步 - 准确率: 0.7483 - 损失: 3.8605
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 26毫秒/步 - 准确率: 0.7305 - 损失: 4.1423 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 43/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 132毫秒/步 - 准确率: 0.7188 - 损失: 4.5277
5/8 ━━━━━━━━━━━━ [37m━━━━━━━━ 0秒 15毫秒/步 - 准确率: 0.6698 - 损失: 5.2541
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 27毫秒/步 - 准确率: 0.6831 - 损失: 4.9995 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 44/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 1秒 149毫秒/步 - 准确率: 0.7188 - 损失: 4.3368
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 12毫秒/步 - 准确率: 0.6884 - 损失: 4.8941
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 26毫秒/步 - 准确率: 0.6877 - 损失: 4.9237 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 45/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 113毫秒/步 - 准确率: 0.7188 - 损失: 3.6048
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 9毫秒/步 - 准确率: 0.6953 - 损失: 4.5189
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 24毫秒/步 - 准确率: 0.6914 - 损失: 4.6078 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 46/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 120毫秒/步 - 准确率: 0.7188 - 损失: 4.5277
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 11毫秒/步 - 准确率: 0.7298 - 损失: 4.2710
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 25毫秒/步 - 准确率: 0.7214 - 损失: 4.4175 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 47/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 117毫秒/步 - 准确率: 0.7500 - 损失: 4.0295
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 11毫秒/步 - 准确率: 0.6962 - 损失: 4.8892
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 26毫秒/步 - 准确率: 0.6981 - 损失: 4.8478 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 48/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 122毫秒/步 - 准确率: 0.7812 - 损失: 3.4540
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 9毫秒/步 - 准确率: 0.7095 - 损失: 4.5553
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 21毫秒/步 - 准确率: 0.7080 - 损失: 4.5585 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 49/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 117毫秒/步 - 准确率: 0.6875 - 损失: 4.5707
7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0秒 10毫秒/步 - 准确率: 0.6914 - 损失: 4.7756
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.6939 - 损失: 4.7445 - 验证准确率: 0.7705 - 验证损失: 3.6992
Epoch 50/50
1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0秒 124毫秒/步 - 准确率: 0.7188 - 损失: 4.0735
6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0秒 11毫秒/步 - 准确率: 0.7049 - 损失: 4.3802
8/8 ━━━━━━━━━━━━━━━━━━━━ 0秒 22毫秒/步 - 准确率: 0.6987 - 损失: 4.5132 - 验证准确率: 0.7705 - 验证损失: 3.6992
<keras.src.callbacks.history.History at 0x747bef08e590>
我们很快获得了 80% 的验证准确率。
要获得新样本的预测,您可以简单地调用 model.predict()
。您只需要做两件事
convert_to_tensor
sample = {
"age": 60,
"sex": 1,
"cp": 1,
"trestbps": 145,
"chol": 233,
"fbs": 1,
"restecg": 2,
"thalach": 150,
"exang": 0,
"oldpeak": 2.3,
"slope": 3,
"ca": 0,
"thal": "fixed",
}
# Given the category (in the sample above - key) and the category value (in the sample above - value),
# we return its one-hot encoding
def get_cat_encoding(cat, cat_value):
# Create a list of zeros with the same length as categories
encoding = [0] * len(cat)
# Find the index of category_value in categories and set the corresponding position to 1
if cat_value in cat:
encoding[cat.index(cat_value)] = 1
return encoding
for name, value in sample.items():
if name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
sample.update(
{
name: get_cat_encoding(
CATEGORICAL_FEATURES_WITH_VOCABULARY[name], sample[name]
)
}
)
# Convert inputs to tensors
input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = model.predict(input_dict)
print(
f"This particular patient had a {100 * predictions[0][0]:.1f} "
"percent probability of having a heart disease, "
"as evaluated by our model."
)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0秒 77毫秒/步
1/1 ━━━━━━━━━━━━━━━━━━━━ 0秒 79毫秒/步
This particular patient had a 0.0 percent probability of having a heart disease, as evaluated by our model.