Skip to content

vllm.compilation.partition_rules

logger module-attribute

logger = init_logger(__name__)

inductor_partition_rule_context

inductor_partition_rule_context(splitting_ops: list[str])

Context manager to temporarily register Inductor partition rules.

Registers custom partition rules for specified operators, forcing the Inductor scheduler to partition the graph at these operators. The rules are automatically restored to their previous state on exit.

Parameters:

Name Type Description Default
splitting_ops list[str]

List of operator names to partition on.

required
Source code in vllm/compilation/partition_rules.py
@contextlib.contextmanager
def inductor_partition_rule_context(splitting_ops: list[str]):
    """Context manager to temporarily register Inductor partition rules.

    Registers custom partition rules for specified operators, forcing the
    Inductor scheduler to partition the graph at these operators. The rules
    are automatically restored to their previous state on exit.

    Args:
        splitting_ops: List of operator names to partition on.
    """
    if not splitting_ops:
        logger.debug("No partition ops provided; skipping rule registration.")
        yield
        return

    # Save current state before registering

    saved_splitting_ops: list[str] = list(
        torch._inductor.config.custom_should_partition_ops
    )
    torch._inductor.config.custom_should_partition_ops = splitting_ops

    logger.debug(
        "Registered inductor partition rules for %d operators", len(splitting_ops)
    )

    try:
        yield
    finally:
        # Clear and restore previous state
        torch._inductor.config.custom_should_partition_ops = saved_splitting_ops
        logger.debug("Restored previous partition rules state.")

should_split

should_split(node: Node, splitting_ops: list[str]) -> bool

Check if a node should be split for dynamo graph partition. It operates on dynamo graph, so the node.target can be anything. We need to check and split only on OpOverload and OpOverloadPacket.

Source code in vllm/compilation/partition_rules.py
def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
    """
    Check if a node should be split for dynamo graph partition.
    It operates on dynamo graph, so the node.target can be anything.
    We need to check and split only on OpOverload and OpOverloadPacket.
    """

    if node.op != "call_function":
        return False

    target = node.target

    if isinstance(target, torch._ops.OpOverloadPacket):
        # Example: "aten::add"
        return target._qualified_op_name in splitting_ops

    if isinstance(target, torch._ops.OpOverload):
        # Example: "aten::add"
        packet_name = target.name()

        # Example: "aten::add.default"
        op_overload_name = f"{packet_name}.{target._overloadname}"
        return op_overload_name in splitting_ops or packet_name in splitting_ops

    return False