代码示例 / 结构化数据 / 使用超参数调优的门控残差网络和变量选择网络进行分类

使用门控残差和变量选择网络进行分类并进行超参数调优

作者: Humbulani Ndou
创建日期 2025/03/17
最后修改日期 2025/03/17
描述: 对使用超参数调优的门控残差网络和变量选择网络进行预测。

ⓘ 此示例使用 Keras 2

在 Colab 中查看 GitHub 源码


简介

下面的示例扩展了脚本 structured_data/classification_with_grn_and_vsn.py,通过集成使用 AutokerasKerasTuner 的超参数调优。关于从这两个包中使用的 API 的具体细节将在相关代码部分详细介绍。

本示例演示了 Bryan Lim 等人在 面向可解释多步时间序列预测的时序融合 Transformer (TFT) 中提出的门控残差网络 (GRN) 和变量选择网络 (VSN) 在结构化数据分类中的应用。GRN 赋予模型仅在需要时应用非线性处理的灵活性。VSN 允许模型软性地移除任何可能对性能产生负面影响的不必要噪声输入。这些技术结合起来,有助于提高深度神经网络模型的学习能力。

请注意,本示例仅实现了论文中描述的 GRN 和 VSN 组件,而不是整个 TFT 模型,因为 GRN 和 VSN 本身也可用于结构化数据学习任务。

要运行代码,您需要使用 TensorFlow 2.3 或更高版本。


数据集

我们的数据集 由克利夫兰诊所心脏病部门提供。它是一个包含 303 行的 CSV 文件。每一行包含有关患者(一个样本)的信息,每一列描述患者的一个属性(一个特征)。我们使用这些特征来预测患者是否患有心脏病(二元分类)。

以下是每个特征的描述

描述 特征类型
年龄 年龄(以年为单位) 数值型
性别 (1 = 男性;0 = 女性) 类别型
CP 胸痛类型(0、1、2、3、4) 类别型
Trestbpd 静息血压(入院时以毫米汞柱为单位) 数值型
Chol 血清胆固醇(以毫克/分升为单位) 数值型
FBS 空腹血糖(以 120 毫克/分升为单位)(1 = 是;0 = 否) 类别型
RestECG 静息心电图结果(0、1、2) 类别型
Thalach 达到的最高心率 数值型
Exang 运动引起的心绞痛(1 = 是;0 = 否) 类别型
Oldpeak 运动相对于静息引起的 ST 段压低 数值型
Slope 运动峰值 ST 段的斜率 数值型
CA 通过荧光镜检查着色的主要血管数量(0-3) 数值型和类别型均可
Thal 3 = 正常;6 = 固定缺损;7 = 可逆缺损 类别型
目标 心脏病诊断(1 = 是;0 = 否) 目标

设置

import os
import subprocess
import tarfile
import numpy as np
import pandas as pd
import tree
from typing import Optional, Union

os.environ["KERAS_BACKEND"] = "tensorflow"  # or jax, or torch

# Keras imports
import keras
from keras import layers

# KerasTuner imports
import keras_tuner
from keras_tuner import HyperParameters

# AutoKeras imports
import autokeras as ak
from autokeras.utils import utils, types

准备数据

让我们下载数据并将其加载到 Pandas 数据框中

file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
dataframe = pd.read_csv(file_url)

该数据集包含 303 个样本,每个样本有 14 列(13 个特征加上目标标签)

dataframe.shape
(303, 14)

这是几个样本的预览

dataframe.head()
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal 目标
0 63 1 1 145 233 1 2 150 0 2.3 3 0 fixed 0
1 67 1 4 160 286 0 2 108 1 1.5 2 3 normal 1
2 67 1 4 120 229 0 2 129 1 2.6 2 2 reversible 0
3 37 1 3 130 250 0 0 187 0 3.5 3 0 normal 0
4 41 0 2 130 204 0 2 172 0 1.4 1 0 normal 0

最后一列“target”表示患者是否患有心脏病(1)或没有(0)。

让我们将数据分割成训练集和验证集

val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
train_dataframe = dataframe.drop(val_dataframe.index)

print(
    f"Using {len(train_dataframe)} samples for training "
    f"and {len(val_dataframe)} for validation"
)
Using 242 samples for training and 61 for validation

定义数据集元数据

在这里,我们定义数据集的元数据,这将有助于读取和解析数据为输入特征,并根据其类型对输入特征进行编码。

COLUMN_NAMES = [
    "age",
    "sex",
    "cp",
    "trestbps",
    "chol",
    "fbs",
    "restecg",
    "thalach",
    "exang",
    "oldpeak",
    "slope",
    "ca",
    "thal",
    "target",
]
# Target feature name.
TARGET_FEATURE_NAME = "target"
# Numeric feature names.
NUMERIC_FEATURE_NAMES = ["age", "trestbps", "thalach", "oldpeak", "slope", "chol"]
# Categorical features and their vocabulary lists.
# Note that we add 'v=' as a prefix to all categorical feature values to make
# sure that they are treated as strings.

CATEGORICAL_FEATURES_WITH_VOCABULARY = {
    feature_name: sorted(
        [
            # Integer categorcal must be int and string must be str
            value if dataframe[feature_name].dtype == "int64" else str(value)
            for value in list(dataframe[feature_name].unique())
        ]
    )
    for feature_name in COLUMN_NAMES
    if feature_name not in list(NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME])
}
# All features names.
FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
    CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
)

使用 Keras 层进行特征预处理

以下特征是编码为整数的分类特征

  • sex
  • cp
  • fbs
  • restecg
  • exang
  • ca

我们将使用独热编码来编码这些特征。我们有两种选择:

  • 使用 CategoryEncoding(),这需要知道输入值的范围,并且在输入超出范围时会报错。
  • 使用 IntegerLookup(),它将为输入构建一个查找表,并为未知输入值保留一个输出索引。

在此示例中,我们想要一个简单的解决方案,能够处理推理时超出范围的输入,因此我们将使用 IntegerLookup()

我们还有一个编码为字符串的分类特征:thal。我们将创建所有可能特征的索引,并使用 StringLookup() 层进行编码输出。

最后,以下特征是连续数值特征

  • age
  • trestbps
  • chol
  • thalach
  • oldpeak
  • slope

对于这些特征中的每一个,我们将使用 Normalization() 层来确保每个特征的均值为 0,标准差为 1。

下面,我们定义一个实用函数来执行这些操作

  • process 函数用于对字符串或整数分类特征进行独热编码。
# Tensorflow required for tf.data.Dataset
import tensorflow as tf


# We process our datasets elements here (categorical) and convert them to indices to avoid this step
# during model training since only tensorflow support strings.
def encode_categorical(features, target):
    for f in features:
        if f in CATEGORICAL_FEATURES_WITH_VOCABULARY:
            # Create a lookup to convert a string values to an integer indices.
            # Since we are not using a mask token nor expecting any out of vocabulary
            # (oov) token, we set mask_token to None and  num_oov_indices to 0.
            cls = (
                layers.StringLookup
                if features[f].dtype == "string"
                else layers.IntegerLookup
            )
            features[f] = cls(
                vocabulary=CATEGORICAL_FEATURES_WITH_VOCABULARY[f],
                mask_token=None,
                num_oov_indices=0,
                output_mode="binary",
            )(features[f])

    # Change features from OrderedDict to Dict to match Inputs as they are Dict.
    return dict(features), target

让我们为每个数据框生成 tf.data.Dataset 对象

def dataframe_to_dataset(dataframe):
    dataframe = dataframe.copy()
    labels = dataframe.pop("target")
    ds = (
        tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
        .map(encode_categorical)
        .shuffle(buffer_size=len(dataframe))
    )
    return ds


train_ds = dataframe_to_dataset(train_dataframe)
val_ds = dataframe_to_dataset(val_dataframe)

每个 Dataset 都返回一个元组 (input, target),其中 input 是特征字典,target 是值 01

for x, y in train_ds.take(1):
    print("Input:", x)
    print("Target:", y)
Input: {'age': <tf.Tensor: shape=(), dtype=int64, numpy=37>, 'sex': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'cp': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 1, 0])>, 'trestbps': <tf.Tensor: shape=(), dtype=int64, numpy=120>, 'chol': <tf.Tensor: shape=(), dtype=int64, numpy=215>, 'fbs': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'restecg': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 0, 0])>, 'thalach': <tf.Tensor: shape=(), dtype=int64, numpy=170>, 'exang': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'oldpeak': <tf.Tensor: shape=(), dtype=float64, numpy=0.0>, 'slope': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'ca': <tf.Tensor: shape=(4,), dtype=int64, numpy=array([1, 0, 0, 0])>, 'thal': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 1, 0])>}
Target: tf.Tensor(0, shape=(), dtype=int64)

让我们对数据集进行批处理

train_ds = train_ds.batch(32)
val_ds = val_ds.batch(32)

子类化 Autokeras Graph

在这里,我们子类化 Autokeras 的 Graph

  • build:我们重写此方法,以便能够处理作为字典传递的 Inputs。在结构化数据分析中,Inputs 通常作为字典为每个感兴趣的特征传递。
class Graph(ak.graph.Graph):

    def build(self, hp):
        """Build the HyperModel into a Keras Model."""
        keras_nodes = {}
        keras_input_nodes = []
        for node in self.inputs:
            node_id = self._node_to_id[node]
            input_node = node.build_node(hp)
            output_node = node.build(hp, input_node)
            keras_input_nodes.append(input_node)
            keras_nodes[node_id] = output_node
        for block in self.blocks:
            temp_inputs = (
                {
                    n.name: keras_nodes[self._node_to_id[n]]
                    for n in block.inputs
                    if isinstance(n, ak.Input)
                }
                if isinstance(block.inputs[0], ak.Input)
                else [keras_nodes[self._node_to_id[n]] for n in block.inputs]
            )
            outputs = tree.flatten(block.build(hp, inputs=temp_inputs))
            for n, o in zip(block.outputs, outputs):
                keras_nodes[self._node_to_id[n]] = o
        model = keras.models.Model(
            keras_input_nodes,
            [
                keras_nodes[self._node_to_id[output_node]]
                for output_node in self.outputs
            ],
        )
        return self._compile_keras_model(hp, model)

    def _compile_keras_model(self, hp, model):
        # Specify hyperparameters from compile(...)
        optimizer_name = hp.Choice(
            "optimizer",
            ["adam", "sgd"],
            default="adam",
        )
        learning_rate = hp.Choice(
            "learning_rate", [1e-1, 1e-2, 1e-3, 1e-4, 2e-5, 1e-5], default=1e-3
        )
        if optimizer_name == "adam":
            optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        elif optimizer_name == "sgd":
            optimizer = keras.optimizers.SGD(learning_rate=learning_rate)
        model.compile(
            optimizer=optimizer,
            metrics=self._get_metrics(),
            loss=self._get_loss(),
        )
        return model

子类化 Autokeras Input

在这里,我们子类化 Autokeras 的 Input 节点对象,并将 dtype 属性从 None 重写为用户提供的值。我们还重写 build_node 方法,以使用用户提供的名称作为 Input 层的名称。

class Input(ak.Input):
    def __init__(self, dtype, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        # Override dtype to a user dtype value
        self.dtype = dtype
        self.name = name

    def build_node(self, hp):
        return keras.Input(name=self.name, shape=self.shape, dtype=self.dtype)

子类化 ClassificationHead

在这里,我们子类化 Autokeras ClassificationHead 并重写 init 方法,并添加 get_expected_shape 方法来推断标签的形状。我们移除预处理功能,因为我们更倾向于手动进行。

class ClassifierHead(ak.ClassificationHead):

    def __init__(
        self,
        num_classes: Optional[int] = None,
        multi_label: bool = False,
        loss: Optional[types.LossType] = None,
        metrics: Optional[types.MetricsType] = None,
        dropout: Optional[float] = None,
        **kwargs,
    ):
        self.num_classes = num_classes
        self.multi_label = multi_label
        self.dropout = dropout
        if metrics is None:
            metrics = ["accuracy"]
        if loss is None:
            loss = self.infer_loss()
        ak.Head.__init__(self, loss=loss, metrics=metrics, **kwargs)
        self.shape = self.get_expected_shape()

    def get_expected_shape(self):
        # Compute expected shape from num_classes.
        if self.num_classes == 2 and not self.multi_label:
            return [1]
        return [self.num_classes]

GatedLinearUnit 层

这是一个在脚本 structured_data/classification_with_grn_vsn.py 中定义的 Keras 层。关于此层的更多详细信息可在相关脚本中找到。

class GatedLinearUnit(layers.Layer):
    def __init__(self, num_units, activation, **kwargs):
        super().__init__(**kwargs)
        self.linear = layers.Dense(num_units)
        self.sigmoid = layers.Dense(num_units, activation=activation)

    def call(self, inputs):
        return self.linear(inputs) * self.sigmoid(inputs)

    def build(self):
        self.built = True

GatedResidualNetwork 层

这是一个在脚本 structured_data/classification_with_grn_vsn.py 中定义的 Keras 层。关于此层的更多详细信息可在相关脚本中找到。

class GatedResidualNetwork(layers.Layer):

    def __init__(
        self, num_units, dropout_rate, activation, use_layernorm=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.num_units = num_units
        self.use_layernorm = use_layernorm
        self.elu_dense = layers.Dense(num_units, activation=activation)
        self.linear_dense = layers.Dense(num_units)
        self.dropout = layers.Dropout(dropout_rate)
        self.gated_linear_unit = GatedLinearUnit(num_units, activation)
        self.layer_norm = layers.LayerNormalization()
        self.project = layers.Dense(num_units)

    def call(self, inputs, hp):
        x = self.elu_dense(inputs)
        x = self.linear_dense(x)
        x = self.dropout(x)
        if inputs.shape[-1] != self.num_units:
            inputs = self.project(inputs)
        x = inputs + self.gated_linear_unit(x)
        use_layernorm = self.use_layernorm
        if use_layernorm is None:
            use_layernorm = hp.Boolean("use_layernorm", default=True)
        if use_layernorm:
            x = self.layer_norm(x)
        return x

    def build(self):
        self.built = True

构建 Autokeras VariableSelection Block

我们已将以下 Keras 层转换为 Autokeras Block,以包含待调优的超参数。请参阅 Autokeras Blocks API 来编写自定义 Blocks。

class VariableSelection(ak.Block):
    def __init__(
        self,
        num_units: Optional[Union[int, HyperParameters.Choice]] = None,
        dropout_rate: Optional[Union[float, HyperParameters.Choice]] = None,
        activation: Optional[Union[str, HyperParameters.Choice]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dropout = utils.get_hyperparameter(
            dropout_rate,
            HyperParameters().Choice("dropout", [0.0, 0.25, 0.5], default=0.0),
            float,
        )
        self.num_units = utils.get_hyperparameter(
            num_units,
            HyperParameters().Choice(
                "num_units", [16, 32, 64, 128, 256, 512, 1024], default=16
            ),
            int,
        )
        self.activation = utils.get_hyperparameter(
            activation,
            HyperParameters().Choice(
                "vsn_activation", ["sigmoid", "elu"], default="sigmoid"
            ),
            str,
        )

    def build(self, hp, inputs):
        num_units = utils.add_to_hp(self.num_units, hp, "num_units")
        dropout_rate = utils.add_to_hp(self.dropout, hp, "dropout_rate")
        activation = utils.add_to_hp(self.activation, hp, "activation")
        concat_inputs = []
        # Project the features to 'num_units' dimension
        for input_ in inputs:
            if input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
                concat_inputs.append(
                    keras.layers.Dense(units=num_units)(inputs[input_])
                )
            else:
                # Create a Normalization layer for our feature
                normalizer = layers.Normalization()
                # Prepare a Dataset that only yields our feature
                feature_ds = train_ds.map(lambda x, y: x[input_]).map(
                    lambda x: keras.ops.expand_dims(x, -1)
                )
                # Learn the statistics of the data
                normalizer.adapt(feature_ds)
                # Normalize the input feature
                normal_feature = normalizer(inputs[input_])
                concat_inputs.append(
                    keras.layers.Dense(units=num_units)(normal_feature)
                )
        v = layers.concatenate(concat_inputs)
        v = GatedResidualNetwork(
            num_units=num_units, dropout_rate=dropout_rate, activation=activation
        )(v, hp=hp)
        v = keras.ops.expand_dims(
            layers.Dense(units=len(inputs), activation=activation)(v), axis=-1
        )
        x = []
        x += [
            GatedResidualNetwork(num_units, dropout_rate, activation)(i, hp=hp)
            for i in concat_inputs
        ]
        x = keras.ops.stack(x, axis=1)
        return keras.ops.squeeze(
            keras.ops.matmul(keras.ops.transpose(v, axes=[0, 2, 1]), x), axis=1
        )

我们创建 HyperModel(来自 KerasTuner)的 Inputs,这些 Inputs 将被构建为 Keras Input 对象。

# Categorical features have different shapes after the encoding, dependent on the
# vocabulary or unique values of each feature. We create them accordinly to match the
# input data elements generated by tf.data.Dataset after pre-processing them
def create_model_inputs():
    inputs = {
        f: (
            Input(
                name=f,
                shape=(len(CATEGORICAL_FEATURES_WITH_VOCABULARY[f]),),
                dtype="int64",
            )
            if f in CATEGORICAL_FEATURES_WITH_VOCABULARY
            else Input(name=f, shape=(1,), dtype="float32")
        )
        for f in FEATURE_NAMES
    }
    return inputs

KerasTuner HyperModel

在这里,我们使用 Autokeras Functional API 来构建一个 Blocks 组成的网络SSS,它将被构建成一个 KerasTuner HyperModel,最终构建成一个 Keras Model。

class MyHyperModel(keras_tuner.HyperModel):

    def build(self, hp):
        inputs = create_model_inputs()
        features = VariableSelection()(inputs)
        outputs = ClassifierHead(num_classes=2, multi_label=False)(features)
        model = Graph(inputs=inputs, outputs=outputs)
        model = model.build(hp)
        return model

    def fit(self, hp, model, *args, **kwargs):
        return model.fit(
            *args,
            # Tune whether to shuffle the data in each epoch.
            shuffle=hp.Boolean("shuffle"),
            **kwargs,
        )

使用 RandomSearch Tuner 查找最佳超参数

我们使用 RandomSearch Tuner 在搜索空间中查找超参数。我们还显示了搜索空间。

print("Start training and searching for the best model...")

tuner = keras_tuner.RandomSearch(
    MyHyperModel(),
    objective="val_accuracy",
    max_trials=3,
    overwrite=True,
    directory="my_dir",
    project_name="tune_hypermodel",
)

# Show the search space summary
print("Tuner search space summary:\n")
tuner.search_space_summary()
# Search for best model
tuner.search(train_ds, epochs=2, validation_data=val_ds)
Trial 3 Complete [00h 00m 16s]
val_accuracy: 0.8032786846160889
Best val_accuracy So Far: 0.8032786846160889
Total elapsed time: 00h 00m 34s

提取最佳模型

# Get the top model.
models = tuner.get_best_models(num_models=1)
best_model = models[0]
best_model.summary()
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:757: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 346 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)         Output Shape          Param #  Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ age (InputLayer)    │ (None, 1)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ chol (InputLayer)   │ (None, 1)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ oldpeak             │ (None, 1)         │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ slope (InputLayer)  │ (None, 1)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ thalach             │ (None, 1)         │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ trestbps            │ (None, 1)         │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32     │ (None, 1)         │          0 │ age[0][0]         │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ ca (InputLayer)     │ (None, 4)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_2   │ (None, 1)         │          0 │ chol[0][0]        │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cp (InputLayer)     │ (None, 5)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ exang (InputLayer)  │ (None, 2)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ fbs (InputLayer)    │ (None, 2)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_6   │ (None, 1)         │          0 │ oldpeak[0][0]     │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ restecg             │ (None, 3)         │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ sex (InputLayer)    │ (None, 2)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_9   │ (None, 1)         │          0 │ slope[0][0]       │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ thal (InputLayer)   │ (None, 5)         │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_11  │ (None, 1)         │          0 │ thalach[0][0]     │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_12  │ (None, 1)         │          0 │ trestbps[0][0]    │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization       │ (None, 1)         │          3 │ cast_to_float32[ │
│ (Normalization)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_1   │ (None, 4)         │          0 │ ca[0][0]          │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization_1     │ (None, 1)         │          3 │ cast_to_float32_… │
│ (Normalization)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_3   │ (None, 5)         │          0 │ cp[0][0]          │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_4   │ (None, 2)         │          0 │ exang[0][0]       │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_5   │ (None, 2)         │          0 │ fbs[0][0]         │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization_2     │ (None, 1)         │          3 │ cast_to_float32_… │
│ (Normalization)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_7   │ (None, 3)         │          0 │ restecg[0][0]     │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_8   │ (None, 2)         │          0 │ sex[0][0]         │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization_3     │ (None, 1)         │          3 │ cast_to_float32_… │
│ (Normalization)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ cast_to_float32_10  │ (None, 5)         │          0 │ thal[0][0]        │
│ (CastToFloat32)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization_4     │ (None, 1)         │          3 │ cast_to_float32_… │
│ (Normalization)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization_5     │ (None, 1)         │          3 │ cast_to_float32_… │
│ (Normalization)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense (Dense)       │ (None, 16)        │         32 │ normalization[0]… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_1 (Dense)     │ (None, 16)        │         80 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_2 (Dense)     │ (None, 16)        │         32 │ normalization_1[ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_3 (Dense)     │ (None, 16)        │         96 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_4 (Dense)     │ (None, 16)        │         48 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_5 (Dense)     │ (None, 16)        │         48 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_6 (Dense)     │ (None, 16)        │         32 │ normalization_2[ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_7 (Dense)     │ (None, 16)        │         64 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_8 (Dense)     │ (None, 16)        │         48 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_9 (Dense)     │ (None, 16)        │         32 │ normalization_3[ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_10 (Dense)    │ (None, 16)        │         96 │ cast_to_float32_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_11 (Dense)    │ (None, 16)        │         32 │ normalization_4[ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_12 (Dense)    │ (None, 16)        │         32 │ normalization_5[ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ concatenate         │ (None, 208)       │          0 │ dense[0][0],      │
│ (Concatenate)       │                   │            │ dense_1[0][0],    │
│                     │                   │            │ dense_2[0][0],    │
│                     │                   │            │ dense_3[0][0],    │
│                     │                   │            │ dense_4[0][0],    │
│                     │                   │            │ dense_5[0][0],    │
│                     │                   │            │ dense_6[0][0],    │
│                     │                   │            │ dense_7[0][0],    │
│                     │                   │            │ dense_8[0][0],    │
│                     │                   │            │ dense_9[0][0],    │
│                     │                   │            │ dense_10[0][0],   │
│                     │                   │            │ dense_11[0][0],   │
│                     │                   │            │ dense_12[0][0]    │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      7,536 │ concatenate[0][0] │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_18 (Dense)    │ (None, 13)        │        221 │ gated_residual_n… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ expand_dims         │ (None, 13, 1)     │          0 │ dense_18[0][0]    │
│ (ExpandDims)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense[0][0]       │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_1[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_2[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_3[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_4[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_5[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_6[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_7[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_8[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_9[0][0]     │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_10[0][0]    │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_11[0][0]    │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gated_residual_net… │ (None, 16)        │      1,120 │ dense_12[0][0]    │
│ (GatedResidualNetw… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transpose           │ (None, 1, 13)     │          0 │ expand_dims[0][0] │
│ (Transpose)         │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stack (Stack)       │ (None, 13, 16)    │          0 │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
│                     │                   │            │ gated_residual_n… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ matmul (Matmul)     │ (None, 1, 16)     │          0 │ transpose[0][0],  │
│                     │                   │            │ stack[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ squeeze (Squeeze)   │ (None, 16)        │          0 │ matmul[0][0]      │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout_14          │ (None, 16)        │          0 │ squeeze[0][0]     │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_84 (Dense)    │ (None, 1)         │         17 │ dropout_14[0][0]  │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ classifier_head_1   │ (None, 1)         │          0 │ dense_84[0][0]    │
│ (Activation)        │                   │            │                   │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 23,024 (89.96 KB)
 Trainable params: 23,006 (89.87 KB)
 Non-trainable params: 18 (96.00 B)

在新数据上的推理

要获取新样本的预测,您可以简单地调用 model.predict()。您只需要做两件事:

  1. 将标量包装到列表中,以便拥有批次维度(模型只处理数据批次,而不是单个样本)。
  2. 对每个特征调用 convert_to_tensor
sample = {
    "age": 60,
    "sex": 1,
    "cp": 1,
    "trestbps": 145,
    "chol": 233,
    "fbs": 1,
    "restecg": 2,
    "thalach": 150,
    "exang": 0,
    "oldpeak": 2.3,
    "slope": 3,
    "ca": 0,
    "thal": "fixed",
}


# Given the category (in the sample above - key) and the category value (in the sample above - value),
# we return its one-hot encoding
def get_cat_encoding(cat, cat_value):
    # Create a list of zeros with the same length as categories
    encoding = [0] * len(cat)
    # Find the index of category_value in categories and set the corresponding position to 1
    if cat_value in cat:
        encoding[cat.index(cat_value)] = 1
    return encoding


for name, value in sample.items():
    if name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
        sample.update(
            {
                name: get_cat_encoding(
                    CATEGORICAL_FEATURES_WITH_VOCABULARY[name], sample[name]
                )
            }
        )
# Convert inputs to tensors
input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = best_model.predict(input_dict)

print(
    f"This particular patient had a {100 * predictions[0][0]:.1f} "
    "percent probability of having a heart disease, "
    "as evaluated by our model."
)

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 136ms/step



1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 162ms/step

This particular patient had a 28.1 percent probability of having a heart disease, as evaluated by our model.