Flax 模块生命周期#
本设计说明适用于已经熟悉 Flax Linen 模块但想了解该抽象背后的设计原则的用户。本说明应该让您很好地理解模块 API 构建所基于的假设和保证。如果您还没有模块的实际经验,请查看快速入门指南。
Flax Linen 模块在 Flax 核心之上提供了 Pythonic 抽象。模块抽象允许您在 JAX 之上创建具有状态、参数和随机性的类。这是一个关于 Module 类的设计和行为的实用指南。最后,您应该可以轻松地脱离常规,以新的方式使用模块。
概述#
定义#
让我们从模块生命周期的高级概述开始。首先,定义一个简单的模块
class MLP(nn.Module):
  # 1. Attribute annotations
  hidden_size: int
  out_size: int
  # 2. The ``setup`` method
  def setup(self):
    self.hidden = nn.Dense(self.hidden_size)
    self.out = nn.Dense(self.out_size)
  # 3. User methods
  def __call__(self, x):
    a = self.hidden(x)
    h = nn.relu(a)
    return self.out(h)
这个模块由以下部分组成:
- 属性注释,定义为数据类字段。这些注释自动定义一个构造函数。 
- ``setup`` 方法,用于创建子模块并将它们分配给属性。 
- 用户方法。按照惯例,大多数模块只有一个 - __call__方法,但您可以定义多个方法或使用不同的方法名称。
构造/初始化#
现在我们想构造和使用 MLP 模块
mlp = MLP(hidden_size=5, out_size=3)
x = jax.numpy.ones((1, 2))
variables = mlp.init(random.key(0), x)
y = mlp.apply(variables, x)
首先,我们构造一个 MLP 的实例并传递构造属性。请注意,如果您不习惯函数式编程模式,这里的构造与您期望的有所不同。MLP 构造函数实际上不创建任何变量或任何内部状态。最好将其视为模块的规范或模板,其中包含功能但没有数据。
让我们仔细看看初始化。令人惊讶的是,Flax 中实际上没有单独的初始化路径。调用 init 只是 apply 的一种特殊情况,您也可以将其写成
# equivalent to: variables = mlp.init(random.key(0), x)
_, variables = mlp.apply({}, x, rngs={"params": random.key(0)}, mutable=True)
因此,init 只是 apply 的一个包装器,其中
- 我们调用一个没有初始变量的模块(一个空字典)。 
- 始终传递一个名为 - "params"的 PRNG 生成器,用于随机初始化参数(使用参数初始化函数)。
- 所有变量集合都设置为可变( - mutable=True)。当集合是可变的时,可以更新现有变量,并且可以创建新变量。因此,在- init内部,可以在任何变量集合中初始化变量,并且它们都添加到返回的变量字典中。
生命周期#
现在您已经了解了 init 是 apply 的一种特殊情况,让我们更详细地看看 .apply(...)。实际上,模块的大部分复杂性都存在于 apply 方法中。“模块生命周期”包括构造和 apply -ing 一个模块。我们可以将模块生命周期总结如下:
- 我们构造 - mlp = MLP(hidden_size=5, out_size=3),使得- mlp.hidden_size=5和- mlp.out_size=3。
- 然后,调用 - mlp.apply,它会- 创建一个 - mlp的克隆,我们称之为- mlp_copy。
- 调用 - mlp_copy.setup()。
- 返回 - mlp_copy.__call__()的输出,并可选地返回使用关键字参数- mutable=指定为可变的变量集合。
 
请注意,生命周期包括克隆模块实例。这样做是为了确保 apply 可以被视为纯函数(即,如果您传入相同的参数,它将返回相同的输出)。您将在后面的顶级模块部分中了解更多详细信息。
变量#
“变量”这个词在编程和数学中无处不在。但是,重要的是要很好地理解变量在 JAX 和 Flax 的上下文中的含义。在 Flax 模块内部,变量的行为就像您对 Python 的期望一样。它们被初始化一次,读取,甚至可能不时更新。但是,JAX 没有变量的概念。相反,值存储在类似于 NumPy 数组的数组中 - 但有一个重要的区别:它们是不可变的。
init 和 apply 方法将变量作为嵌套字典返回,其中字符串键和 JAX 数组位于叶节点。在顶层,每个键对应一个变量集合。在每个集合内部,嵌套的字典结构与 Module 层次结构相对应。变量字典是不可变的,因此实际上只是变量所处状态的快照。当再次调用 apply 时,变量字典将作为参数传递。这样,变量的状态与上一次 init / apply 调用完成时的状态相同。
注意
模块字段使用 field_name: TypeHint 语法声明(与数据类相同)。如果没有类型提示,则属性被视为该类的静态属性。如果您无法指定类型,则可以使用 typing.Any 作为通配符类型。
紧凑模块#
Linen 提供了一种替代 API,可以更紧凑地定义模块。这对于模块仅包含一个使用参数和/或子模块的方法的常见情况尤其有用。使用紧凑的 API,可以将 MLP 重写如下:
class CompactMLP(nn.Module):
  hidden_size: int
  out_size: int
  @nn.compact
  def __call__(self, x):
    a = nn.Dense(self.hidden_size)(x)
    h = nn.relu(a)
    return nn.Dense(self.out_size)(h)
一个紧凑的 Module 在精神上类似于一个函数。它提供了一种简洁的表示法,并将外部交互限制在函数的输入和返回值上。在这种情况下,简洁的表示法可能使其他人更容易理解模块的作用。无需在 setup 和 __call__ 方法之间来回跳转来理解子模块正在做什么。相反,只需从上到下阅读一次 __call__ 方法,就可以获得一个简洁的概述。如果您正在实现具有许多超参数的复杂模块,这可能会产生重大影响。有关如何在 setup 和 compact 之间做出选择的实用指南,请参阅setup 或 compact。
内联定义子模块和/或变量的另一个好处是,您可以在构造变量时向方法添加参数。最常见的例子是使用形状信息来确定参数的形状,如下所示
class CompactScaledMLP(nn.Module):
  hidden_size: int
  out_size: int
  @nn.compact
  def __call__(self, x):
    scale = self.param("scale", nn.initializers.ones_init(), x.shape[-1:])
    x *= scale[None]
    a = nn.Dense(self.hidden_size)(x)
    h = nn.relu(a)
    return nn.Dense(self.out_size)(h)
许多标准的 Linen 模块,如 nn.Dense,已经使用了形状推断,从而避免了指定输入形状(如 Dense 层的输入特征数量)的需求。
紧凑的控制流#
如果您没有显式提供子模块的名称(使用传递给模块构造函数的 name= 关键字参数),则您定义子模块的顺序将决定子模块的名称。由于 name 决定了参数如何映射到子模块,因此您必须小心地将控制流与自动生成的名称混合使用。使用控制流可能会更改顺序或完全删除某些子模块。如果子模块仅应在某些构造参数存在的情况下才存在,则此功能很有用。但是,当控制流取决于模块的输入参数时,您应该小心。例如,以下模块会中断
class WrongModule(nn.Module):
  @nn.compact
  def __call__(self, x, mode):
    if mode == "encode":
      return nn.Dense(features=8)(x)
    elif mode == "decode":
      return nn.Dense(features=4)(x)
上面的模块会中断,因为编码器或解码器路径都会构造一个名为“Dense_0”的模块。这意味着这两个模块将共享参数,这不是预期的。实际上,这两个模块不能共享参数,因为它们各自具有不同的特征数量。
- 这个问题可以通过多种方式解决
- 提供显式名称 
- 在 - setup中创建模块
- 或将构造函数移出控制流。 
 
后者如下所示
class CorrectModule(nn.Module):
  @nn.compact
  def __call__(self, x, mode):
    encoder = nn.Dense(8)
    decoder = nn.Dense(4)
    if mode == "encode":
      return encoder(x)
    elif mode == "decode":
      return decoder(x)
在上面的示例中,构造顺序是固定的。构造完成后,可以以任意顺序使用子模块。
注意
紧凑的模块与 React hooks 非常相似。
顶层模块#
当在“顶层”创建模块实例时,它将处于“未绑定”状态,也就是说,它没有附加任何变量。“顶层”意味着它不是作为另一个模块类中的子模块构建的。除了调用 init 和 apply 之外,您对未绑定的模块无能为力。还要注意,不会在未绑定的模块上调用 setup,因此您只能访问构造参数。请参阅 未来工作 部分,了解将来这种情况可能会如何变化。
为什么顶层模块始终未绑定?#
当我们调用 apply 时,会创建一个顶层模块的副本,该副本实际上将保存变量和 PRNG 序列。这种有状态的“绑定”克隆仅在我们执行 apply 方法时存在。这样做的原因是,如果您创建一个有状态的对象并在 apply 函数返回之前将其销毁,则 apply 函数本身的行为就像一个纯函数。纯函数有两个约束
- 如果您输入相同的参数,它将返回相同的输出 
- 它不会更改函数外部的任何内容。这意味着您不能操作在纯函数外部可访问的有状态对象。 
纯函数有很多优点,但在使用 JAX 时,它们通常是必不可少的。例如,大多数代码需要使用 jax.jit 进行编译才能快速运行,并且一旦创建了模块,您可能希望使用 jax.grad 优化其参数。但是,这些 API 需要一个纯函数,并且不能直接在有状态的绑定 Module 实例上工作。此外,纯函数允许与其他库进行灵活的互操作性。例如,我们建议使用 Optax 来优化参数。Optax 中的优化器期望并返回 JAX 数组的 PyTree 以进行优化,就像 Linen 模块的 apply 函数一样。
克隆#
为了使这种方法可靠地工作,我们需要明确定义的克隆行为。Flax 没有像 Python 的 deepcopy 那样依赖复杂的嵌套克隆过程,而是强制执行 Module 完全由其构造参数定义。因此,克隆模块简化为使用其原始构造参数调用构造函数。由于 Module 充当不可变的 dataclass,因此构造参数直接映射到实例属性。在 setup 或 __post_init__ 中计算的非构造属性也应仅依赖于构造参数,以确保明确定义的克隆。
设置#
setup 方法通常用作普通 Python 类中的构造函数钩子 (__init__)。但是,对于更高级的用例,最好意识到它与构造函数并不完全相同。
setup 仅在模块绑定后才会被调用。通常,这不是问题,因为大多数模块会(几乎)立即绑定(作为 init 和 apply 的一部分)。在 setup 中,子模块在分配给属性时会变为绑定。在 nn.compact 修饰的方法中,子模块会在构造时立即绑定。如上一节所述,顶层模块永远不会绑定,因此在构造时不会调用 setup。这意味着您无法从未绑定的顶层模块访问在 setup 中分配的属性。
class TopLevelAccess(nn.Module):
  def setup(self):
    self.foo = nn.Dense(2)
mdl = TopLevelAccess()
assert not hasattr(mdl, "foo")  # foo is not defined because setup is not called
setup 方法不是在 Module 绑定后立即调用,而是在您与 Module 实例交互时才调用(例如:调用方法或访问属性)。这不应影响 Module 的行为,但延迟执行有时会影响调试期间的日志语句和堆栈跟踪。有关 函数化 的部分将解释为什么我们需要 setup 首先是延迟的。
函数化#
到目前为止,我们有一个纯 apply 函数,它通常使用一些 JAX 转换进行转换,并且在 apply 内部,我们有一个有状态的模块实例可以使用。换句话说:在模块外部,我们处于一个函数式世界中,我们可以利用 JAX 的函数式转换,而在模块内部,我们可以利用 Flax 的有状态变量和 PRNG 序列,并且 apply 方法是我们这两个世界之间的桥梁。
但是,如果我们想在模块内部使用 JAX 转换怎么办?答案是函数化。
此过程本身很繁琐且容易出错,但由 Flax 在内部处理。在高层,我们可以将其总结如下。对于在模块中定义的方法 fn
- 收集应在 JAX 转换内部可用的模块的状态(变量和 PRNG 序列)并拍摄快照。 
- 使用原始参数和收集的状态调用 JAX 转换。然后在转换内部 - 解包状态并重新创建模块 
- 调用用户代码 - fn
- 收集更新的变量和 rng,并将其与 - fn中的原始返回值一起返回
 
- 使用从转换返回的更新状态更新原始状态。 
有关函数化和提升的更深入解释,请参阅提升的转换设计说明。
实际影响#
在大多数情况下,函数化是自动为你处理的。但仍然有一些约束需要你考虑。最重要的是,Flax 只处理有状态的原语(Linen 变量和 RNG),而不是任意的有状态 Python 代码。最重要的是:你不能闭包有状态对象和 Module 对象,因为它们对 Flax 的内部机制(以及一般的 JAX)是不可见的。
class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    dense = nn.Dense(x.shape[-1])
    fn = lambda x: dense(x) + 1
    # simply calling inner works fine
    # return self.inner(x, fn)
    # but applying a transformation doesn't:
    vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
    return vmap_inner(self, x, fn)
  def inner(self, x, fn):
    for i in range(3):
      x = fn(x)
    return x
这里 inner 接受一个闭包 Module 实例的函数。在这个例子中,这可以正常工作,因为我们没有使用提升的转换来转换 inner 方法。大多数方法不会被转换,但了解如何使 Module 方法可转换是很好的。
可转换性的主要障碍是 JAX 不识别的类型。JAX 只理解 Pytree 参数;即 (Jax) numpy ndarrays 和 Python 数字/布尔值的任意嵌套的 Python 容器(字典、列表、元组)。Flax 允许使用 flax.struct API 定义与 Pytree 兼容的数据类。
函数闭包是意外地从转换中隐藏 JAX 数组或 Linen Module 的最常见方式。但是,如果你想传递也与 JAX 和 Linen 转换兼容的闭包,则有一个简单的解决方法
class Partial(flax.struct.PyTreeNode):
  fn: Callable = flax.struct.field(pytree_node=False)
  args: Iterable[Any]
  def __call__(self, *args, **kwargs):
    return self.fn(*(tuple(self.args) + args), **kwargs)
class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    dense = nn.Dense(x.shape[-1])
    fn = lambda mdl, x: mdl(x) + 1
    vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
    return vmap_inner(self, x, Partial(fn, [dense]))
  def inner(self, x, fn):
    for i in range(3):
      x = fn(x)
    return x
这里,闭包是使用 Flax 数据类实现的。函数本身用 flax.struct.field(pytree_node=False) 注释,以表明它不包含 JAX 数组或 Linen Module。另一方面,部分应用的 args 被视为 pytree 容器。我们将闭包重写为使用 Partial。现在可以使用提升的转换来转换 inner 方法。
未来工作#
为未绑定模块设置#
当涉及到构造后初始化字段时,当前的 Module 抽象尤其具有限制性。在当前的 Module API 中,setup 方法是初始化 Module 实例字段的地方。由于 setup 仅在绑定 Module 上调用,因此完整的 Module API 在 setup 内部可用,包括变量声明。但是,通常我们实际上不需要任何有状态的 API 来初始化字段。事实上,最常见的情况是我们只是想声明一个子模块。更重要的是,检查子模块以进行调试或部分运行模型通常很有用。例如考虑
class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = Encoder(...)
    self.decoder = Decoder(...)
想象一下,我们只想使用 auto_encoder.decoder.apply(decoder_variables, x) 调用解码器。使用当前的 setup API,这是行不通的,因为我们必须先绑定变量,然后才能调用 setup 并定义解码器属性。当然,我们可以使用与 setup 中相同的属性手动构造 Decoder Module,但这在许多情况下并不理想。
有两种可能的解决方案可以使此用例更符合人体工程学。首先,可以在绑定之前立即在构造后运行 setup。这意味着你仍然可以创建子模块,但你不能再定义或操作变量。因此,这将是一个重大更改,并且需要一个新的 API 来延迟定义变量
或者,可以引入一个额外的特殊方法,该方法在 Module 构造后并在绑定之前立即运行。在这种情况下,setup 方法将保留其原始语义。
