def find_op_nodes(
    op: OpOverload | OpOverloadPacket, graph: fx.Graph
) -> Iterator[fx.Node]:
    if isinstance(op, OpOverloadPacket):
        for overload in op.overloads():
            overload_op = getattr(op, overload)
            yield from find_op_nodes(overload_op, graph)
        return
    assert isinstance(op, OpOverload)
    if not op._schema.is_mutable:
        yield from graph.find_nodes(op="call_function", target=op)
    for n in graph.find_nodes(op="call_function", target=auto_functionalized):
        if n.args[0] == op:
            yield n