Keras 3 API 文档 / Ops API / 核心 Ops

核心操作

[源文件]

associative_scan 函数

keras.ops.associative_scan(f, elems, reverse=False, axis=0)

使用并行关联二元运算执行扫描。

此操作与 scan 类似,主要区别在于 associative_scan 是一个并行实现,可能带来显著的性能优势,尤其是在 jit 编译时。限制在于,它只能在 f 是二元关联运算时使用(即必须验证 f(a, f(b, c)) == f(f(a, b), c))。

有关关联扫描的介绍,请参阅本文献:Blelloch, Guy E. 1990. 前缀和及其应用

参数

  • f:实现具有签名 r = f(a, b) 的关联二元运算的 Python 可调用对象。函数 f 必须是关联的,即它必须满足方程 f(a, f(b, c)) == f(f(a, b), c)。输入和结果是与 elems 匹配的数组(可能是嵌套的 Python 树结构)。每个数组都有一个维度,该维度取代了 axis 维度。f 应该沿着 axis 维度逐元素应用。结果 r 具有与两个输入 ab 相同的形状(和结构)。
  • elems:数组(可能是嵌套的 Python 树结构),每个数组都有一个大小为 num_elemsaxis 维度。
  • reverse:一个布尔值,指示扫描是否应相对于 axis 维度反向执行。
  • axis:一个整数,标识应该进行扫描的轴。

返回

具有与 elems 相同的形状和结构的数组(可能是嵌套的 Python 树结构),其中 axis 的第 k 个元素是沿着 axis 递归应用 f 来组合 elems 的前 k 个元素的结果。例如,给定 elems = [a, b, c, ...],结果将是 [a, f(a, b), f(f(a, b), c), ...]

示例

>>> sum_fn = lambda x, y: x + y
>>> xs = keras.ops.arange(5)
>>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0)
>>> ys
[0, 1, 3, 6, 10]
>>> sum_fn = lambda x, y: [x[0] + y[0], x[1] + y[1], x[2] + y[2]]
>>> xs = [keras.ops.array([[1, 2]]) for _ in range(3)]
>>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0)
>>> ys
[[1, 3], [1, 3], [1, 3]]

[源文件]

cast 函数

keras.ops.cast(x, dtype)

将张量转换为所需的 dtype。

参数

  • x:张量或变量。
  • dtype:目标类型。

返回

指定 dtype 的张量。

示例

>>> x = keras.ops.arange(4)
>>> x = keras.ops.cast(x, dtype="float16")

[源文件]

cond 函数

keras.ops.cond(pred, true_fn, false_fn)

有条件地应用 true_fnfalse_fn

参数

  • pred:布尔标量类型
  • true_fn:当 pred == True 时返回输出的可调用对象。
  • false_fn:当 pred == False 时返回输出的可调用对象。

返回

取决于 pred 的 true_fnfalse_fn 的输出。


[源文件]

convert_to_numpy 函数

keras.ops.convert_to_numpy(x)

将张量转换为 NumPy 数组。

参数

  • x:张量。

返回

NumPy 数组。


[源文件]

convert_to_tensor 函数

keras.ops.convert_to_tensor(x, dtype=None, sparse=None, ragged=None)

将 NumPy 数组或 Python 数组转换为张量。

当前后端对应的原生张量,除非设置了 dtypesparseragged 参数,否则保持不变。

参数

  • x:NumPy 数组、Python 数组(可以是嵌套的)或后端张量。
  • dtype:目标类型。如果为 None,则使用 x 的类型。
  • sparse:是否保留稀疏张量。False 将导致稀疏张量密集化。默认值 None 表示仅当后端支持时才保留稀疏张量。
  • ragged:是否保留不规则张量。False 将导致不规则张量密集化。默认值 None 表示仅当后端支持时才保留不规则张量。

返回

指定 dtype 和稀疏性的后端张量。

示例

>>> x = np.array([1, 2, 3])
>>> y = keras.ops.convert_to_tensor(x)

[源文件]

custom_gradient 函数

keras.ops.custom_gradient(f)

用于定义具有自定义梯度的函数的装饰器。

此装饰器允许对操作序列的梯度进行细粒度控制。这可能因多种原因而有用,包括为操作序列提供更高效或数值更稳定的梯度。

参数

  • f:函数 f(*args),返回一个元组 (output, grad_fn),其中
    • args 是函数的张量输入的序列(嵌套结构)。
    • output 是将 forward_fn 中的操作应用于 args 后产生的张量输出(嵌套结构)。
    • grad_fn 是一个具有签名 grad_fn(*args, upstream) 的函数,它返回一个与(展平的)args 大小相同的张量元组:output 中的张量相对于 args 中的张量的导数。upstream 是一个张量或张量序列,它持有 output 中每个张量的初始值梯度。

返回

函数 h(*args) 返回与 f(*args)[0] 相同的值,其梯度由 f(*args)[1] 确定。

示例

  1. 后端无关示例。
@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)

    def grad(*args, upstream=None):
        if upstream is None:
            (upstream,) = args
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

    return ops.log(1 + e), grad

注意,返回梯度计算的 grad 函数需要 argsupstream 关键字参数,具体取决于设置的后端。对于 JAX 和 TensorFlow 后端,它只需要一个参数,而对于 PyTorch 后端,它可能使用 upstream 参数。

使用 TensorFlow/JAX 后端时,grad(upstream) 就足够了。对于 PyTorch,grad 函数需要 *argsupstream,例如 def grad(*args, upstream)。请按照前面的示例使用 @ops.custom_gradient,以兼容所有后端。

  1. 以下是 JAX 和 TensorFlow 特定示例
@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)
    def grad(upstream):
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
    return ops.log(1 + e), grad
  1. 最后,这里是 PyTorch 特定示例,使用 *argsupstream
@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)
    def grad(*args, upstream):
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
    return ops.log(1 + e), grad

[源文件]

dtype 函数

keras.ops.dtype(x)

将张量输入的 dtype 作为标准化字符串返回。

请注意,由于标准化,dtype 将与后端特定的 dtype 版本不相等。

参数

  • x:张量。此函数将尝试访问输入张量的 dtype 属性。

返回

表示输入张量 dtype 的字符串,例如 "float32"

示例

>>> x = keras.ops.zeros((8, 12))
>>> keras.ops.dtype(x)
'float32'

[源文件]

erf 函数

keras.ops.erf(x)

计算 x 的误差函数,逐元素进行。

参数

  • x:输入张量。

返回

x 具有相同 dtype 的张量。

示例

>>> x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0])
>>> keras.ops.erf(x)
array([-0.99998 , -0.99532, -0.842701,  0.,  0.842701], dtype=float32)

[源文件]

erfinv 函数

keras.ops.erfinv(x)

计算 x 的逆误差函数,逐元素进行。

参数

  • x:输入张量。

返回

x 具有相同 dtype 的张量。

示例

>>> x = np.array([-0.5, -0.2, -0.1, 0.0, 0.3])
>>> keras.ops.erfinv(x)
array([-0.47694, -0.17914, -0.08886,  0. ,  0.27246], dtype=float32)

[源文件]

extract_sequences 函数

keras.ops.extract_sequences(x, sequence_length, sequence_stride)

将最后一个轴的维度扩展为长度为 sequence_length 的序列。

使用大小为 sequence_length 的窗口,以 sequence_stride 为步幅在输入的最后一个轴上滑动,并将最后一个轴替换为 [num_sequences, sequence_length] 序列。

如果沿最后一个轴的维度为 N,则序列数量可以通过以下方式计算

num_sequences = 1 + (N - sequence_length) // sequence_stride

参数

  • x:输入张量。
  • sequence_length:表示序列长度的整数。
  • sequence_stride:表示序列跳跃大小的整数。

返回

形状为 [..., num_sequences, sequence_length] 的序列张量。

示例

>>> x = keras.ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
>>> extract_sequences(x, 3, 2)
array([[1, 2, 3],
   [3, 4, 5]])

[源文件]

fori_loop 函数

keras.ops.fori_loop(lower, upper, body_fun, init_val)

For 循环实现。

参数

  • lower:循环变量的初始值。
  • upper:循环变量的上限。
  • body_fun:表示循环体的可调用对象。必须接受两个参数:循环变量和循环状态。循环状态应该由此函数更新并返回。
  • init_val:循环状态的初始值。

返回

循环结束后的最终状态。

示例

>>> lower = 0
>>> upper = 10
>>> body_fun = lambda i, s: (i + 1, s + i)
>>> init_val = 0
>>> keras.ops.fori_loop(lower, upper, body_fun, init_val)
45

[源文件]

in_top_k 函数

keras.ops.in_top_k(targets, predictions, k)

检查目标是否在前 k 个预测中。

参数

  • targets:真实标签的张量。
  • predictions:预测标签的张量。
  • k:表示要考虑的预测数量的整数。

返回

targets 形状相同的布尔张量,其中每个元素指示相应的目标是否在前 k 个预测中。

示例

>>> targets = keras.ops.convert_to_tensor([2, 5, 3])
>>> predictions = keras.ops.convert_to_tensor(
... [[0.1, 0.4, 0.6, 0.9, 0.5],
...  [0.1, 0.7, 0.9, 0.8, 0.3],
...  [0.1, 0.6, 0.9, 0.9, 0.5]])
>>> in_top_k(targets, predictions, k=3)
array([ True False  True], shape=(3,), dtype=bool)

[源文件]

is_tensor 函数

keras.ops.is_tensor(x)

检查给定对象是否为张量。

注意:这会检查特定于后端的张量,因此如果您的后端是 PyTorch 或 JAX,则传递 TensorFlow 张量将返回 False

参数

  • x:变量。

返回

如果 x 是张量,则为 True,否则为 False


[源文件]

logsumexp 函数

keras.ops.logsumexp(x, axis=None, keepdims=False)

计算张量中元素指数和的对数。

参数

  • x:输入张量。
  • axis:一个整数或整数元组,指定计算和的轴。如果为 None,则计算所有元素的和。默认为 None
  • keepdims:一个布尔值,指示在计算和时是否保留输入张量的维度。默认为 False

返回

包含 x 中元素指数和的对数的张量。

示例

>>> x = keras.ops.convert_to_tensor([1., 2., 3.])
>>> logsumexp(x)
3.407606

[源文件]

map 函数

keras.ops.map(f, xs)

在领先的数组轴上映射函数。

类似于 Python 的内置 map,但输入和输出是堆叠数组的形式。考虑使用 vectorized_map() 变换代替,除非您需要逐元素应用函数以减少内存使用或与其他控制流原语进行异构计算。

xs 是数组类型时,map() 的语义由以下 Python 实现给出

def map(f, xs):
    return np.stack([f(x) for x in xs])

参数

  • f:定义了在 xs 的第一个轴或多个轴上逐元素应用的函数的可调用对象。
  • xs:要沿着主轴映射的值。

返回

映射后的值。

示例

>>> f = lambda x: x**2
>>> xs = keras.ops.arange(10)
>>> ys = keras.ops.map(f, xs)
>>> ys
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
>>> f = lambda x: {"y1": x**2, "y2": x * 10}  # Can have nested outputs
>>> ys = keras.ops.map(f, xs)
>>> ys["y1"]
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
>>> ys["y2"]
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90]

[源文件]

rsqrt 函数

keras.ops.rsqrt(x)

逐元素计算 x 的平方根的倒数。

参数

  • x:输入张量

返回

x 具有相同 dtype 的张量。

示例

>>> x = keras.ops.convert_to_tensor([1.0, 10.0, 100.0])
>>> keras.ops.rsqrt(x)
array([1.0, 0.31622776, 0.1], dtype=float32)

[源文件]

saturate_cast 函数

keras.ops.saturate_cast(x, dtype)

执行到所需 dtype 的安全饱和转换。

饱和转换可防止在转换为值范围较小的 dtype 时发生数据类型溢出。例如,ops.cast(ops.cast([-1, 256], "float32"), "uint8") 返回 [255, 0],但 ops.saturate_cast(ops.cast([-1, 256], "float32"), "uint8") 返回 [0, 255]

参数

  • x:张量或变量。
  • dtype:目标类型。

返回

指定 dtype 的安全转换张量。

示例

使用双三次插值调整图像大小时,可能会产生超出原始范围的值。

>>> image2x2 = np.array([0, 1, 254, 255], dtype="uint8").reshape(1, 2, 2, 1)
>>> image4x4 = tf.image.resize(image2x2, (4, 4), method="bicubic")
>>> print(image4x4.numpy().squeeze())
>>> # [[-22.500004 -22.204624 -21.618908 -21.32353 ]
>>> #  [ 52.526054  52.82143   53.407146  53.70253 ]
>>> #  [201.29752  201.59288  202.17859  202.47395 ]
>>> #  [276.32355  276.61893  277.20465  277.50006 ]]

将此调整大小后的图像重新转换为 uint8 将导致溢出。

>>> image4x4_casted = ops.cast(image4x4, "uint8")
>>> print(image4x4_casted.numpy().squeeze())
>>> # [[234 234 235 235]
>>> #  [ 52  52  53  53]
>>> #  [201 201 202 202]
>>> #  [ 20  20  21  21]]

饱和转换为 uint8 会在转换前将值裁剪到 uint8 范围,并且不会导致溢出。

>>> image4x4_saturate_casted = ops.saturate_cast(image4x4, "uint8")
>>> print(image4x4_saturate_casted.numpy().squeeze())
>>> # [[  0   0   0   0]
>>> #  [ 52  52  53  53]
>>> #  [201 201 202 202]
>>> #  [255 255 255 255]]

[源文件]

scan 函数

keras.ops.scan(f, init, xs=None, length=None, reverse=False, unroll=1)

在领先的数组轴上扫描函数并携带状态。

xs 的类型是数组类型或 None,且 ys 的类型是数组类型时,scan() 的语义大致由以下 Python 实现给出

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

循环携带值 carry (init) 在所有迭代中必须保持固定的形状和 dtype。

在 TensorFlow 中,y 的形状和 dtype 必须与 carry 匹配。在其他后端中则无需如此。

参数

  • f:定义每个循环迭代逻辑的可调用对象。它接受两个参数,第一个是循环携带值,第二个是 xs 沿其主轴的切片。此可调用对象返回一个对,其中第一个表示循环携带的新值,第二个表示输出的切片。
  • init:循环携带的初始值。这可以是标量、张量或任何嵌套结构。它必须与 f 返回的第一个元素的结构匹配。
  • xs:可选值,用于沿其主轴扫描。这可以是张量或任何嵌套结构。如果未提供 xs,则必须指定 length 来定义循环迭代次数。默认为 None
  • length:可选整数,指定循环迭代次数。如果未提供 length,则默认为 xs 中数组主轴的大小。默认为 None
  • reverse:可选布尔值,指定是正向还是反向运行扫描迭代,相当于反转 xsys 中数组的主轴。
  • unroll:可选的正整数或布尔值,指定在循环的单次迭代中展开多少次扫描迭代。如果提供整数,则确定在循环的单次卷绕迭代中运行多少次展开的循环迭代。如果提供布尔值,则确定循环是完全展开 (unroll=True) 还是完全不展开 (unroll=False)。请注意,仅 JAX 和 TensorFlow 后端支持展开。

返回

一个对,其中第一个元素表示最终的循环携带值,第二个元素表示当在输入的主轴上扫描时 f 的堆叠输出。

示例

>>> sum_fn = lambda c, x: (c + x, c + x)
>>> init = keras.ops.array(0)
>>> xs = keras.ops.array([1, 2, 3, 4, 5])
>>> carry, result = keras.ops.scan(sum_fn, init, xs)
>>> carry
15
>>> result
[1, 3, 6, 10, 15]

[源文件]

scatter 函数

keras.ops.scatter(indices, values, shape)

返回形状为 shape 的张量,其中 indices 设置为 values

从高层来看,此操作执行 zeros[indices] = updates 并返回输出。它等价于

zeros = keras.ops.zeros(shape)
output = keras.ops.scatter_update(zeros, indices, values)

参数

  • indices:一个张量或列表/元组,指定 values 中值的索引。
  • values:一个张量,要在 indices 处设置的值。
  • shape:输出张量的形状。

示例

>>> indices = [[0, 1], [1, 1]]
>>> values = np.array([1., 1.])
>>> keras.ops.scatter(indices, values, shape=(2, 2))
array([[0., 1.],
       [0., 1.]])

[源文件]

scatter_update 函数

keras.ops.scatter_update(inputs, indices, updates)

通过在散布(稀疏)索引处进行更新来更新输入。

从高层来看,此操作执行 inputs[indices] = updates。假设 inputs 是形状为 (D0, D1, ..., Dn) 的张量,scatter_update 有 2 种主要用法。

  1. indices 是形状为 (num_updates, n) 的二维张量,其中 num_updates 是要执行的更新次数,updates 是形状为 (num_updates,) 的一维张量。例如,如果 inputszeros((4, 4, 4)),并且我们想将 inputs[1, 2, 3]inputs[0, 1, 3] 更新为 1,则可以使用
inputs = np.zeros((4, 4, 4))
indices = [[1, 2, 3], [0, 1, 3]]
updates = np.array([1., 1.])
inputs = keras.ops.scatter_update(inputs, indices, updates)

2 indices 是形状为 (num_updates, k) 的二维张量,其中 num_updates 是要执行的更新次数,k (k < n) 是 indices 中每个索引的大小。updates 是形状为 (num_updates, inputs.shape[k:])n - k 维张量。例如,如果 inputs = np.zeros((4, 4, 4)),并且我们想将 inputs[1, 2, :]inputs[2, 3, :] 更新为 [1, 1, 1, 1],则 indices 的形状将是 (num_updates, 2) (k = 2),而 updates 的形状将是 (num_updates, 4) (inputs.shape[2:] = 4)。请参阅下面的代码

inputs = np.zeros((4, 4, 4))
indices = [[1, 2], [2, 3]]
updates = np.array([[1., 1., 1, 1,], [1., 1., 1, 1,])
inputs = keras.ops.scatter_update(inputs, indices, updates)

参数

  • inputs:一个张量,即要更新的张量。
  • indices:形状为 (N, inputs.ndim) 的张量或列表/元组,指定要更新的索引。N 是要更新的索引数量,必须等于 updates 的第一个维度。
  • updates:一个张量,要放入 inputsindices 处的新值。

返回

一个张量,具有与 inputs 相同的形状和 dtype。


[源文件]

segment_max 函数

keras.ops.segment_max(data, segment_ids, num_segments=None, sorted=False)

计算张量中各段的最大值。

参数

  • data:输入张量。
  • segment_ids:一个 N 维张量,包含 data 中每个元素的段索引。data.shape[:len(segment_ids.shape)] 应该匹配。
  • num_segments:一个整数,表示总段数。如果未指定,则根据 segment_ids 中的最大值推断。
  • sorted:一个布尔值,指示 segment_ids 是否已排序。默认为 False

返回

包含各段最大值的张量,其中每个元素表示 data 中相应段的最大值。

示例

>>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200])
>>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2])
>>> num_segments = 3
>>> keras.ops.segment_max(data, segment_ids, num_segments)
array([2, 20, 200], dtype=int32)

[源文件]

segment_sum 函数

keras.ops.segment_sum(data, segment_ids, num_segments=None, sorted=False)

计算张量中各段的总和。

参数

  • data:输入张量。
  • segment_ids:一个 N 维张量,包含 data 中每个元素的段索引。segment id 的维度数量应严格小于或等于 data 的维度数量。
  • num_segments:一个整数,表示总段数。如果未指定,则根据 segment_ids 中的最大值推断。
  • sorted:一个布尔值,指示 segment_ids 是否已排序。默认为 False

返回

包含各段总和的张量,其中每个元素表示 data 中相应段的总和。

示例

>>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200])
>>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2])
>>> num_segments = 3
>>> keras.ops.segment_sum(data, segment_ids,num_segments)
array([3, 30, 300], dtype=int32)

[源文件]

shape 函数

keras.ops.shape(x)

获取张量输入的形状。

注意:在 TensorFlow 后端上,当 x 是具有动态形状的 tf.Tensor 时,在编译函数的上下文中是动态的维度将具有 tf.Tensor 值,而不是静态整数值。

参数

  • x:一个张量。此函数将尝试访问输入张量的 shape 属性。

返回

由整数或 None 值组成的元组,表示输入张量的形状。

示例

>>> x = keras.ops.zeros((8, 12))
>>> keras.ops.shape(x)
(8, 12)

[源文件]

slice 函数

keras.ops.slice(inputs, start_indices, shape)

返回输入张量的切片。

从高层来看,此操作是数组切片(例如 inputs[start_indices: start_indices + shape])的显式替代。与通过方括号进行的切片不同,此操作将接受所有后端上的张量起始索引,这在使用通过其他张量操作动态计算的索引时非常有用。

inputs = np.zeros((5, 5))
start_indices = np.array([3, 3])
shape = np.array([2, 2])
inputs = keras.ops.slice(inputs, start_indices, shape)

参数

  • inputs:一个张量,即要更新的张量。
  • start_indices:形状为 (inputs.ndim,) 的列表/元组,指定更新的起始索引。
  • shape:返回切片的完整形状。

返回

一个张量,具有与 inputs 相同的形状和 dtype。


[源文件]

slice_update 函数

keras.ops.slice_update(inputs, start_indices, updates)

通过切片一个包含更新值的张量来更新输入。

从高层来看,此操作执行 inputs[start_indices: start_indices + updates.shape] = updates。假设 inputs 是形状为 (D0, D1, ..., Dn) 的张量,start_indices 必须是 n 个整数组成的列表/元组,指定起始索引。updates 必须与 inputs 具有相同的秩,并且每个维度的尺寸不得超过 Di - start_indices[i]。例如,如果我们有一个 2D 输入 inputs = np.zeros((5, 5)),并且我们想将最后 2 行和最后 2 列的交集更新为 1,即 inputs[3:, 3:] = np.ones((2, 2)),则可以使用下面的代码

inputs = np.zeros((5, 5))
start_indices = [3, 3]
updates = np.ones((2, 2))
inputs = keras.ops.slice_update(inputs, start_indices, updates)

参数

  • inputs:一个张量,即要更新的张量。
  • start_indices:形状为 (inputs.ndim,) 的列表/元组,指定更新的起始索引。
  • updates:一个张量,要放入 inputsindices 处的新值。updates 必须与 inputs 具有相同的秩。

返回

一个张量,具有与 inputs 相同的形状和 dtype。


[源文件]

stop_gradient 函数

keras.ops.stop_gradient(variable)

停止梯度计算。

参数

  • variable:要禁用梯度计算的张量变量。

返回

禁用梯度计算的变量。

示例

>>> var = keras.backend.convert_to_tensor(
...     [1., 2., 3.],
...     dtype="float32"
... )
>>> var = keras.ops.stop_gradient(var)

[源文件]

switch 函数

keras.ops.switch(index, branches, *operands)

应用 index 指定的 branches 中的一个且仅一个。

如果 index 超出范围,则将其钳制到范围内。

switch 的语义大致由以下 Python 实现给出

def switch(index, branches, *operands):
    index = clamp(0, index, len(branches) - 1)
    return branches[index](*operands)

参数

  • index:指示应用哪个分支函数的整数标量。
  • branches:基于 index 应用的函数序列。
  • operands:应用于所选分支的输入。

返回

根据 index 选择的分支的 branch(*operands) 的输出。

示例

>>> add_fn = lambda x, y: x + y
>>> subtract_fn = lambda x, y: x - y
>>> x = keras.ops.array(2.0)
>>> y = keras.ops.array(0.5)
>>> branches = [add_fn, subtract_fn]
>>> keras.ops.switch(0, branches, x, y)
2.5
>>> keras.ops.switch(1, branches, x, y)
1.5

[源文件]

top_k 函数

keras.ops.top_k(x, k, sorted=True)

查找张量中的前 k 个值及其索引。

参数

  • x:输入张量。
  • k:表示要检索的前 k 个元素的数量的整数。
  • sorted:一个布尔值,指示是否按降序对输出进行排序。默认为 True

返回

包含两个张量的元组。第一个张量包含前 k 个值,第二个张量包含输入张量中前 k 个值的索引。

示例

>>> x = keras.ops.convert_to_tensor([5, 2, 7, 1, 9, 3])
>>> values, indices = top_k(x, k=3)
>>> print(values)
array([9 7 5], shape=(3,), dtype=int32)
>>> print(indices)
array([4 2 0], shape=(3,), dtype=int32)

[源文件]

unstack 函数

keras.ops.unstack(x, num=None, axis=0)

将秩为 R 的张量的给定维度拆分为秩为 (R-1) 的张量。

参数

  • x:输入张量。
  • num:维度轴的长度。如果为 None,则自动推断。
  • axis:沿哪个轴进行拆包。

返回

沿给定轴拆包的张量列表。

示例

>>> x = keras.ops.array([[1, 2], [3, 4]])
>>> keras.ops.unstack(x, axis=0)
[array([1, 2]), array([3, 4])]

[源文件]

vectorized_map 函数

keras.ops.vectorized_map(function, elements)

在张量 elements 的轴 0 上并行映射 function

从示意图来看,在单个张量输入 elements 的情况下,vectorized_map 实现以下功能

def vectorized_map(function, elements)
    outputs = []
    for e in elements:
        outputs.append(function(e))
    return stack(outputs)

在张量可迭代对象 elements 的情况下,它实现以下功能

def vectorized_map(function, elements)
    batch_size = elements[0].shape[0]
    outputs = []
    for index in range(batch_size):
        outputs.append(function([e[index] for e in elements]))
    return np.stack(outputs)

在这种情况下,function 期望接受单个张量参数列表作为输入。


[源文件]

while_loop 函数

keras.ops.while_loop(cond, body, loop_vars, maximum_iterations=None)

While 循环实现。

参数

  • cond:表示循环终止条件的可调用对象。必须接受类似 loop_vars 结构的参数。如果 loop_vars 是元组或列表,则 loop_vars 的每个元素将作为位置参数传递给可调用对象。
  • body:表示循环体的可调用对象。必须接受类似 loop_vars 结构的参数,并返回具有相同结构的更新值。如果 loop_vars 是元组或列表,则 loop_vars 的每个元素将作为位置参数传递给可调用对象。
  • loop_vars:在循环迭代中持续存在的张量状态的任意嵌套结构。
  • maximum_iterations:可选的最大 while 循环迭代次数。如果提供,则 cond 的输出将与一个附加条件进行 AND 运算,该条件确保执行的迭代次数不超过 maximum_iterations

返回

张量列表/元组,具有与 inputs 相同的形状和 dtype。

示例

>>> i = 0
>>> cond = lambda i: i < 10
>>> body = lambda i: i + 1
>>> keras.ops.while_loop(cond, body, i)
10
>>> x, y = 0, 1
>>> cond = lambda x, y: x < 10
>>> body = lambda x, y: (x + 1, y + 1)
>>> keras.ops.while_loop(cond, body, (x, y))
10, 11