Skip to content

utils

Utility helpers for circulax, used to update parameters stored in component groups.

Functions:

| Name | Description |
| --- | --- |
| `update_group_params` | Updates a parameter for ALL instances in a component group. |
| `update_params_dict` | Updates a parameter for a specific instance within a component group. |

update_group_params

update_group_params(
    groups_dict: dict, group_name: str, param_key: str, new_value: float
) -> dict[str, ComponentGroup]

Updates a parameter for ALL instances in a component group.

Source code in circulax/utils.py
def update_group_params(
    groups_dict: dict, group_name: str, param_key: str, new_value: float
) -> dict[str, "ComponentGroup"]:
    """Set ``param_key`` to ``new_value`` for every instance of a component group.

    The group's batched Equinox component (``group.params``) holds the named
    field as an array. That field is replaced by an array of the same shape and
    dtype, filled entirely with ``new_value``.

    Args:
        groups_dict: Mapping of group name to component group. It is not
            mutated.
        group_name: Key of the group to update.
        param_key: Name of the field on the batched component to overwrite.
        new_value: Fill value written into every element of the field.

    Returns:
        A new dict equal to ``groups_dict`` except that ``group_name`` maps to
        the updated group. All other entries are shared with the input.

    Raises:
        KeyError: If ``group_name`` is not a key of ``groups_dict``.
        AttributeError: If the batched component has no field ``param_key``.
    """
    group = groups_dict[group_name]
    component = group.params

    # full_like preserves the existing shape/dtype, so the batch layout is kept.
    filled = jnp.full_like(getattr(component, param_key), new_value)

    def _select_field(module):
        # Field locator for eqx.tree_at; picks the attribute being replaced.
        return getattr(module, param_key)

    updated_component = eqx.tree_at(_select_field, component, filled)
    updated_group = eqx.tree_at(lambda grp: grp.params, group, updated_component)

    # Copy-and-replace so the caller's dict is left untouched.
    return {**groups_dict, group_name: updated_group}

update_params_dict

update_params_dict(
    groups_dict: dict,
    group_name: str,
    instance_name: str,
    param_key: str,
    new_value: float,
) -> dict[str, ComponentGroup]

Updates a parameter for a specific instance within a component group.

Source code in circulax/utils.py
def update_params_dict(
    groups_dict: dict,
    group_name: str,
    instance_name: str,
    param_key: str,
    new_value: float,
) -> dict[str, "ComponentGroup"]:
    """Set ``param_key`` to ``new_value`` for one named instance of a group.

    ``group.index_map`` translates ``instance_name`` into the index used to
    address that instance inside the batched field. Presumably this is an
    integer position along the leading (batch) axis; confirm against the
    group builder. The field is updated out of place with JAX's ``.at[...].set``.

    Args:
        groups_dict: Mapping of group name to component group. It is not
            mutated.
        group_name: Key of the group containing the instance.
        instance_name: Name of the instance whose value is changed.
        param_key: Name of the field on the batched component to modify.
        new_value: Value written at the instance's index.

    Returns:
        A new dict equal to ``groups_dict`` except that ``group_name`` maps to
        the updated group. All other entries are shared with the input.

    Raises:
        KeyError: If ``group_name`` is not in ``groups_dict`` or
            ``instance_name`` is not in the group's ``index_map``.
        AttributeError: If the batched component has no field ``param_key``.
    """
    group = groups_dict[group_name]
    slot = group.index_map[instance_name]

    component = group.params
    field_values = getattr(component, param_key)
    # JAX arrays are immutable; .at[...].set returns a patched copy.
    patched = field_values.at[slot].set(new_value)

    updated_component = eqx.tree_at(
        lambda module: getattr(module, param_key), component, patched
    )
    updated_group = eqx.tree_at(lambda grp: grp.params, group, updated_component)

    # Copy-and-replace so the caller's dict is left untouched.
    return {**groups_dict, group_name: updated_group}