flax.traverse_util 包
用于遍历不可变数据结构的实用程序。
Traversal 可用于迭代和更新复杂的数据结构。Traversals 接受一个对象并返回其内容的子集。例如,Traversal 可以选择对象的属性
>>> from flax import traverse_util
>>> import dataclasses
>>> @dataclasses.dataclass
... class Foo:
... foo: int = 0
... bar: int = 0
...
>>> x = Foo(foo=1)
>>> iterator = traverse_util.TraverseAttr('foo').iterate(x)
>>> list(iterator)
[1]
可以使用组合来构建更复杂的遍历。通常从标识遍历开始并使用方法链来构建预期的 Traversal 是有用的
>>> data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
>>> traversal = traverse_util.t_identity.each()['foo']
>>> iterator = traversal.iterate(data)
>>> list(iterator)
[1, 3]
Traversals 也可以使用 update 方法进行更改
>>> data = {'foo': Foo(bar=2)}
>>> traversal = traverse_util.t_identity['foo'].bar
>>> data = traversal.update(lambda x: x + x, data)
>>> data
{'foo': Foo(foo=0, bar=4)}
Traversals 永远不会更改原始数据。因此,更新本质上返回包含所提供更新的数据副本。
遍历对象
-
class flax.traverse_util.Traversal(*args, **kwargs)[来源]
所有遍历的基类。
-
compose(other)[来源]
组合两个遍历。
-
each()[来源]
遍历所选容器中的每个项目。
-
filter(fn)[来源]
过滤所选值。
-
abstract iterate(inputs)[来源]
迭代此 Traversal 选择的值。
- 参数
inputs – 应遍历的对象。
- 返回
遍历值的迭代器。
-
merge(*traversals)[来源]
组合任意数量的遍历并合并结果。
-
set(values, inputs)[来源]
覆盖 Traversal 选择的值。
- 参数
values – 包含新值的列表。
inputs – 应遍历的对象。
- 返回
具有更新值的新对象。
-
tree()[来源]
遍历 pytree 中的每个项目。
-
abstract update(fn, inputs)[来源]
更新焦点项。
- 参数
fn – 回调函数,将每个遍历项映射到其更新值。
inputs – 应遍历的对象。
- 返回
具有更新值的新对象。
-
class flax.traverse_util.TraverseId(*args, **kwargs)[来源]
标识 Traversal。
-
iterate(inputs)[来源]
迭代此 Traversal 选择的值。
- 参数
inputs – 应遍历的对象。
- 返回
遍历值的迭代器。
-
update(fn, inputs)[来源]
更新焦点项。
- 参数
fn – 回调函数,将每个遍历项映射到其更新值。
inputs – 应遍历的对象。
- 返回
具有更新值的新对象。
-
class flax.traverse_util.TraverseMerge(*args, **kwargs)[来源]
合并一组遍历的选择。
-
iterate(inputs)[源代码]
迭代此 Traversal 选择的值。
- 参数
inputs – 应遍历的对象。
- 返回
遍历值的迭代器。
-
update(fn, inputs)[源代码]
更新焦点项。
- 参数
fn – 回调函数,将每个遍历项映射到其更新值。
inputs – 应遍历的对象。
- 返回
具有更新值的新对象。
-
class flax.traverse_util.TraverseCompose(*args, **kwargs)[源代码]
组合两个遍历。
-
iterate(inputs)[源代码]
迭代此 Traversal 选择的值。
- 参数
inputs – 应遍历的对象。
- 返回
遍历值的迭代器。
-
update(fn, inputs)[源代码]
更新焦点项。
- 参数
fn – 回调函数,将每个遍历项映射到其更新值。
inputs – 应遍历的对象。
- 返回
具有更新值的新对象。
-
class flax.traverse_util.TraverseFilter(*args, **kwargs)[源代码]
根据谓词筛选选定的值。
-
iterate(inputs)[源代码]
迭代此 Traversal 选择的值。
- 参数
inputs – 应遍历的对象。
- 返回
遍历值的迭代器。
-
update(fn, inputs)[源代码]
更新焦点项。
- 参数
fn – 回调函数,将每个遍历项映射到其更新值。
inputs – 应遍历的对象。
- 返回
具有更新值的新对象。
-
class flax.traverse_util.TraverseAttr(*args, **kwargs)[源代码]
遍历对象的属性。
-
iterate(inputs)[源代码]
迭代此 Traversal 选择的值。
- 参数
inputs – 应遍历的对象。
- 返回
遍历值的迭代器。
-
update(fn, inputs)[源代码]
更新焦点项。
- 参数
fn – 回调函数,将每个遍历项映射到其更新值。
inputs – 应遍历的对象。
- 返回
具有更新值的新对象。
-
class flax.traverse_util.TraverseItem(*args, **kwargs)[源代码]
遍历对象的项。
-
iterate(inputs)[源代码]
迭代此 Traversal 选择的值。
- 参数
inputs – 应遍历的对象。
- 返回
遍历值的迭代器。
-
update(fn, inputs)[源代码]
更新焦点项。
- 参数
fn – 回调函数,将每个遍历项映射到其更新值。
inputs – 应遍历的对象。
- 返回
具有更新值的新对象。
-
class flax.traverse_util.TraverseEach(*args, **kwargs)[源代码]
遍历容器的每个项。
-
iterate(inputs)[源代码]
迭代此 Traversal 选择的值。
- 参数
inputs – 应遍历的对象。
- 返回
遍历值的迭代器。
-
update(fn, inputs)[源代码]
更新焦点项。
- 参数
fn – 回调函数,将每个遍历项映射到其更新值。
inputs – 应遍历的对象。
- 返回
具有更新值的新对象。
-
class flax.traverse_util.TraverseTree(*args, **kwargs)[源代码]
遍历 pytree 中的每个项。
-
iterate(inputs)[源代码]
迭代此 Traversal 选择的值。
- 参数
inputs – 应遍历的对象。
- 返回
遍历值的迭代器。
-
update(fn, inputs)[源代码]
更新焦点项。
- 参数
fn – 回调函数,将每个遍历项映射到其更新值。
inputs – 应遍历的对象。
- 返回
具有更新值的新对象。
字典工具
-
flax.traverse_util.flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None)[源代码]
展平嵌套字典。
嵌套的键被展平为一个元组。请参阅 unflatten_dict,了解如何恢复嵌套的字典结构。
示例
>>> from flax.traverse_util import flatten_dict
>>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
>>> flat_xs = flatten_dict(xs)
>>> flat_xs
{('foo',): 1, ('bar', 'a'): 2}
请注意,空字典将被忽略,并且不会被 unflatten_dict 恢复。
- 参数
xs – 一个嵌套字典
keep_empty_nodes – 将空字典替换为 traverse_util.empty_node。
is_leaf – 一个可选的函数,它接受下一个嵌套字典和嵌套键,如果嵌套字典是一个叶子(即,不应进一步展平),则返回 True。
sep – 如果指定,则返回的字典的键将是 sep 连接的字符串(如果 None,则键将是元组)。
- 返回
扁平化的字典。
-
flax.traverse_util.unflatten_dict(xs, sep=None)[源代码]
取消展平字典。
请参阅 flatten_dict
示例
>>> flat_xs = {
... ('foo',): 1,
... ('bar', 'a'): 2,
... }
>>> xs = unflatten_dict(flat_xs)
>>> xs
{'foo': 1, 'bar': {'a': 2}}
- 参数
-
- 返回
嵌套的字典。
-
flax.traverse_util.path_aware_map(f, nested_dict)[源代码]
一个映射函数,它在考虑每个叶子路径的情况下对嵌套字典结构进行操作。
示例
>>> import jax.numpy as jnp
>>> from flax import traverse_util
>>> params = {'a': {'x': 10, 'y': 3}, 'b': {'x': 20}}
>>> f = lambda path, x: x + 5 if 'x' in path else -x
>>> traverse_util.path_aware_map(f, params)
{'a': {'x': 15, 'y': -3}, 'b': {'x': 25}}
- 参数
-
- 返回
一个新的嵌套字典结构,其中包含映射的值。
模型参数遍历
-
class flax.traverse_util.ModelParamTraversal(*args, **kwargs)[源]
使用名称过滤器选择模型参数。
此遍历操作于参数的嵌套字典,并基于 filter_fn 参数选择子集。
有关如何使用 ModelParamTraversal 通过特定的优化器更新参数树的子集,请参见 flax.optim.MultiOptimizer。
-
__init__(filter_fn)[源]
构造一个新的 ModelParamTraversal。
- 参数
filter_fn – 一个函数,它接受参数的完整名称及其值,并返回是否应选择此参数。参数的名称由模块层次结构和参数名称确定(例如:’/module/sub_module/parameter_name’)。