
JVP treatment of mask as a leaf vs pytree #29

@rdyro

The mask in jvp might not match the pytree structure of grad_y.

Unfortunately, I don't have an open-source repro at the moment, but these lines seem to assume that the mask and the gradients are pytrees of the same structure. However, the mask is a leaf, while the gradients are a pytree.

Is that correct?

I believe you recently addressed a similar issue here:
5150e33#diff-21e634aa62155f577c8e87e1b851189b4791db79bdb2593cc957ca86e8cde5ccL328
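For illustration, here is a minimal sketch of the mismatch and one possible way to handle it. The repository's actual code isn't shown here, so the helper names `mask_like` and `masked_tangents` are hypothetical. For simplicity the sketch also assumes the mask is a boolean array used to zero out tangent entries. If the mask is a single leaf and `grad_y` is a dict, zipping the two with `jax.tree_util.tree_map` passes the whole dict to the mapped function as one argument and fails. Broadcasting the leaf mask over every leaf of `grad_y` first avoids that.

```python
import jax
import jax.numpy as jnp


def mask_like(mask, grad_y):
    """Hypothetical helper: return a mask with the same pytree structure as grad_y.

    Assumption: a mask that is a single leaf (e.g. one boolean array) applies
    to every leaf of grad_y; otherwise the two pytrees must already match.
    """
    mask_def = jax.tree_util.tree_structure(mask)
    # A strict leaf has exactly one node, and that node is a leaf.
    if mask_def.num_nodes == 1 and mask_def.num_leaves == 1:
        # Reuse the single mask leaf at every leaf position of grad_y.
        return jax.tree_util.tree_map(lambda _: mask, grad_y)
    if mask_def != jax.tree_util.tree_structure(grad_y):
        raise ValueError(f"mask structure {mask_def} does not match grad_y")
    return mask


def masked_tangents(mask, grad_y):
    # Hypothetical usage: zero out tangent entries where the mask is False.
    mask_tree = mask_like(mask, grad_y)
    return jax.tree_util.tree_map(
        lambda m, g: jnp.where(m, g, jnp.zeros_like(g)), mask_tree, grad_y
    )
```

The explicit structure check means a genuinely mismatched mask (e.g. a dict missing one of `grad_y`'s keys) still raises an error instead of being silently broadcast.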