vllm.compilation.matcher_utils
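
Pattern-matching helpers used by vLLM's compilation fusion passes. Each MatcherCustomOp describes one operation in two ways: forward_custom builds it from the registered custom op (via auto_functionalized), and forward_native builds it from plain PyTorch, so fusion patterns can be traced against whichever path the compiled model actually takes.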

FLASHINFER_ROTARY_OP module-attribute

FLASHINFER_ROTARY_OP = default

QUANT_OPS module-attribute

QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: default,
    kFp8DynamicTensorSym: default,
    kFp8DynamicTokenSym: default,
}

RMS_ADD_OP module-attribute

RMS_ADD_OP = default

RMS_OP module-attribute

RMS_OP = default

ROTARY_OP module-attribute

ROTARY_OP = default

SILU_MUL_OP module-attribute

SILU_MUL_OP = default

MatcherCustomOp

Bases: ABC

Source code in vllm/compilation/matcher_utils.py
class MatcherCustomOp(ABC):
    def __init__(self, enabled: bool):
        config = get_current_vllm_config()
        self.model_dtype = config.model_config.dtype if config.model_config else None
        self.device = config.device_config.device if config.device_config else None

        self.enabled = enabled
        self.forward = self.forward_custom if enabled else self.forward_native

    @abstractmethod
    def forward_custom(self, *args, **kws):
        pass

    @abstractmethod
    def forward_native(self, *args, **kws):
        pass

    def __call__(self, *args, **kws):
        return self.forward(*args, **kws)

    def empty(self, *args, **kws):
        return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)

    def empty_int64(self, *args, **kws):
        return torch.empty(*args, dtype=torch.int64, device=self.device, **kws)

    def empty_f32(self, *args, **kws):
        return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)

    def inputs(self) -> list[torch.Tensor]:
        """Utility for inputs to the pattern"""
        raise NotImplementedError
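
A minimal sketch of how a subclass is typically written. The MatcherScaleByTwo name and its doubling operation are hypothetical and purely illustrative; a real matcher would call a registered custom op in forward_custom:

import torch

from vllm.compilation.matcher_utils import MatcherCustomOp


class MatcherScaleByTwo(MatcherCustomOp):
    def inputs(self) -> list[torch.Tensor]:
        # Sample tensors the pattern matcher traces through __call__.
        return [self.empty(5, 16)]

    def forward_custom(self, x: torch.Tensor) -> torch.Tensor:
        # A real matcher would invoke the registered custom op here,
        # typically through auto_functionalized(...).
        return x * 2

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference path used when the custom op is disabled.
        return x * 2


matcher = MatcherScaleByTwo(enabled=False)
(x,) = matcher.inputs()
out = matcher(x)  # dispatches to forward_native because enabled=False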

device instance-attribute

device = device if device_config else None

enabled instance-attribute

enabled = enabled

forward instance-attribute

forward = forward_custom if enabled else forward_native

model_dtype instance-attribute

model_dtype = dtype if model_config else None

__call__

__call__(*args, **kws)
Source code in vllm/compilation/matcher_utils.py
def __call__(self, *args, **kws):
    return self.forward(*args, **kws)

__init__

__init__(enabled: bool)
Source code in vllm/compilation/matcher_utils.py
def __init__(self, enabled: bool):
    config = get_current_vllm_config()
    self.model_dtype = config.model_config.dtype if config.model_config else None
    self.device = config.device_config.device if config.device_config else None

    self.enabled = enabled
    self.forward = self.forward_custom if enabled else self.forward_native

empty

empty(*args, **kws)
Source code in vllm/compilation/matcher_utils.py
def empty(self, *args, **kws):
    return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)

empty_f32

empty_f32(*args, **kws)
Source code in vllm/compilation/matcher_utils.py
def empty_f32(self, *args, **kws):
    return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)

empty_int64

empty_int64(*args, **kws)
Source code in vllm/compilation/matcher_utils.py
def empty_int64(self, *args, **kws):
    return torch.empty(*args, dtype=torch.int64, device=self.device, **kws)

forward_custom abstractmethod

forward_custom(*args, **kws)
Source code in vllm/compilation/matcher_utils.py
@abstractmethod
def forward_custom(self, *args, **kws):
    pass

forward_native abstractmethod

forward_native(*args, **kws)
Source code in vllm/compilation/matcher_utils.py
@abstractmethod
def forward_native(self, *args, **kws):
    pass

inputs

inputs() -> list[Tensor]

Utility for inputs to the pattern

Source code in vllm/compilation/matcher_utils.py
def inputs(self) -> list[torch.Tensor]:
    """Utility for inputs to the pattern"""
    raise NotImplementedError

MatcherFusedAddRMSNorm

Bases: MatcherCustomOp

Source code in vllm/compilation/matcher_utils.py
class MatcherFusedAddRMSNorm(MatcherCustomOp):
    def __init__(self, epsilon: float, enabled: bool | None = None):
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon

    def inputs(self):
        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
        weight = self.empty(16)
        residual = self.empty(5, 16)
        return [input, weight, residual]

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        _, result, residual = auto_functionalized(
            RMS_ADD_OP,
            input=input,
            residual=residual,
            weight=weight,
            epsilon=self.epsilon,
        )

        return result, residual

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return RMSNorm.forward_static(
            input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
        )
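
A usage sketch. These matchers are normally called while the fusion pass traces patterns; here the native path is forced so the snippet can run standalone, and a vLLM config context is assumed so get_current_vllm_config() can resolve dtype and device. The epsilon value is illustrative:

matcher = MatcherFusedAddRMSNorm(epsilon=1e-6, enabled=False)
x, weight, residual = matcher.inputs()            # sample tensors for tracing
out, new_residual = matcher(x, weight, residual)  # native path; custom RMS_ADD_OP when enabled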

epsilon instance-attribute

epsilon = epsilon

__init__

__init__(epsilon: float, enabled: bool | None = None)
Source code in vllm/compilation/matcher_utils.py
def __init__(self, epsilon: float, enabled: bool | None = None):
    if enabled is None:
        enabled = RMSNorm.enabled()

    super().__init__(enabled)
    self.epsilon = epsilon

forward_custom

forward_custom(
    input: Tensor, weight: Tensor, residual: Tensor
) -> tuple[Tensor, Tensor]
Source code in vllm/compilation/matcher_utils.py
def forward_custom(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
    residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    _, result, residual = auto_functionalized(
        RMS_ADD_OP,
        input=input,
        residual=residual,
        weight=weight,
        epsilon=self.epsilon,
    )

    return result, residual

forward_native

forward_native(
    input: Tensor, weight: Tensor, residual: Tensor
) -> tuple[Tensor, Tensor]
Source code in vllm/compilation/matcher_utils.py
def forward_native(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
    residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    return RMSNorm.forward_static(
        input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
    )

inputs

inputs()
Source code in vllm/compilation/matcher_utils.py
def inputs(self):
    input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
    weight = self.empty(16)
    residual = self.empty(5, 16)
    return [input, weight, residual]

MatcherQuantFP8

Bases: MatcherCustomOp

Source code in vllm/compilation/matcher_utils.py
class MatcherQuantFP8(MatcherCustomOp):
    def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
        if enabled is None:
            enabled = QuantFP8.enabled()

        super().__init__(enabled)
        self.quant_key = quant_key
        assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
        self.QUANT_OP = QUANT_OPS[quant_key]

        assert quant_key.dtype == current_platform.fp8_dtype(), (
            "MatcherQuantFP8 only supports the platform FP8 dtype"
        )
        assert quant_key.scale2 is None
        self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)

    def forward_custom(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        result = torch.empty(
            input.shape, device=input.device, dtype=self.quant_key.dtype
        )

        if self.quant_key.scale.static:
            assert scale is not None
            _, result = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale
            )
            return result, scale
        else:
            assert scale is None
            scale = self.make_scale(input)
            _, result, scale = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
            )
            return result, scale

    def forward_native(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.quant_fp8(input, scale)

    def make_scale(self, input: torch.Tensor):
        normalized_group_shape = _normalize_quant_group_shape(
            input, self.quant_key.scale.group_shape
        )
        scale_shape = (
            input.shape[0] // normalized_group_shape[0],
            input.shape[1] // normalized_group_shape[1],
        )

        return torch.empty(scale_shape, device=input.device, dtype=torch.float32)

    def inputs(self) -> list[torch.Tensor]:
        input = self.empty(5, 16)
        if self.quant_key.scale.static:
            return [input, self.empty_f32(1, 1)]

        return [input]
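
A usage sketch with one of the keys already registered in QUANT_OPS (the key's import location is omitted here; shapes come from inputs(), and the custom path assumes a platform with FP8 support):

matcher = MatcherQuantFP8(kFp8DynamicTokenSym)
tensors = matcher.inputs()         # [input] for dynamic keys, [input, scale] for static ones
quantized, scale = matcher(*tensors)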

QUANT_OP instance-attribute

QUANT_OP = QUANT_OPS[quant_key]

quant_fp8 instance-attribute

quant_fp8 = QuantFP8(static, group_shape)

quant_key instance-attribute

quant_key = quant_key

__init__

__init__(quant_key: QuantKey, enabled: bool | None = None)
Source code in vllm/compilation/matcher_utils.py
def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
    if enabled is None:
        enabled = QuantFP8.enabled()

    super().__init__(enabled)
    self.quant_key = quant_key
    assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
    self.QUANT_OP = QUANT_OPS[quant_key]

    assert quant_key.dtype == current_platform.fp8_dtype(), (
        "MatcherQuantFP8 only supports the platform FP8 dtype"
    )
    assert quant_key.scale2 is None
    self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)

forward_custom

forward_custom(
    input: Tensor, scale: Tensor | None = None
) -> tuple[Tensor, Tensor]
Source code in vllm/compilation/matcher_utils.py
def forward_custom(
    self,
    input: torch.Tensor,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    result = torch.empty(
        input.shape, device=input.device, dtype=self.quant_key.dtype
    )

    if self.quant_key.scale.static:
        assert scale is not None
        _, result = auto_functionalized(
            self.QUANT_OP, result=result, input=input, scale=scale
        )
        return result, scale
    else:
        assert scale is None
        scale = self.make_scale(input)
        _, result, scale = auto_functionalized(
            self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
        )
        return result, scale

forward_native

forward_native(
    input: Tensor, scale: Tensor | None = None
) -> tuple[Tensor, Tensor]
Source code in vllm/compilation/matcher_utils.py
def forward_native(
    self,
    input: torch.Tensor,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    return self.quant_fp8(input, scale)

inputs

inputs() -> list[Tensor]
Source code in vllm/compilation/matcher_utils.py
def inputs(self) -> list[torch.Tensor]:
    input = self.empty(5, 16)
    if self.quant_key.scale.static:
        return [input, self.empty_f32(1, 1)]

    return [input]

make_scale

make_scale(input: Tensor)
Source code in vllm/compilation/matcher_utils.py
def make_scale(self, input: torch.Tensor):
    normalized_group_shape = _normalize_quant_group_shape(
        input, self.quant_key.scale.group_shape
    )
    scale_shape = (
        input.shape[0] // normalized_group_shape[0],
        input.shape[1] // normalized_group_shape[1],
    )

    return torch.empty(scale_shape, device=input.device, dtype=torch.float32)

MatcherRMSNorm

Bases: MatcherCustomOp

Source code in vllm/compilation/matcher_utils.py
class MatcherRMSNorm(MatcherCustomOp):
    def __init__(self, epsilon: float, enabled: bool | None = None):
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon

    def inputs(self):
        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
        weight = self.empty(16)
        return [input, weight]

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        result = torch.empty_like(input)
        # TODO: support non-contiguous input for RMSNorm and remove this
        input_contiguous = input.contiguous()
        _, result = auto_functionalized(
            RMS_OP,
            result=result,
            input=input_contiguous,
            weight=weight,
            epsilon=self.epsilon,
        )

        return result

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        return RMSNorm.forward_static(
            input, self.epsilon, input.size(-1), self.model_dtype, weight
        )
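
A usage sketch mirroring the fused-add variant above (epsilon is illustrative; a vLLM config context is assumed):

matcher = MatcherRMSNorm(epsilon=1e-6)
x, weight = matcher.inputs()
out = matcher(x, weight)  # custom RMS_OP when RMSNorm is enabled, RMSNorm.forward_static otherwise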

epsilon instance-attribute

epsilon = epsilon

__init__

__init__(epsilon: float, enabled: bool | None = None)
Source code in vllm/compilation/matcher_utils.py
def __init__(self, epsilon: float, enabled: bool | None = None):
    if enabled is None:
        enabled = RMSNorm.enabled()

    super().__init__(enabled)
    self.epsilon = epsilon

forward_custom

forward_custom(input: Tensor, weight: Tensor) -> Tensor
Source code in vllm/compilation/matcher_utils.py
def forward_custom(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
) -> torch.Tensor:
    result = torch.empty_like(input)
    # TODO: support non-contiguous input for RMSNorm and remove this
    input_contiguous = input.contiguous()
    _, result = auto_functionalized(
        RMS_OP,
        result=result,
        input=input_contiguous,
        weight=weight,
        epsilon=self.epsilon,
    )

    return result

forward_native

forward_native(input: Tensor, weight: Tensor) -> Tensor
Source code in vllm/compilation/matcher_utils.py
def forward_native(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
) -> torch.Tensor:
    return RMSNorm.forward_static(
        input, self.epsilon, input.size(-1), self.model_dtype, weight
    )

inputs

inputs()
Source code in vllm/compilation/matcher_utils.py
def inputs(self):
    input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
    weight = self.empty(16)
    return [input, weight]

MatcherRotaryEmbedding

Bases: MatcherCustomOp

Source code in vllm/compilation/matcher_utils.py
class MatcherRotaryEmbedding(MatcherCustomOp):
    def __init__(
        self,
        is_neox: bool,
        head_size: int,
        num_heads: int,
        num_kv_heads: int,
        use_flashinfer: bool = False,
        enabled: bool | None = None,
    ) -> None:
        if enabled is None:
            enabled = RotaryEmbedding.enabled()

        super().__init__(enabled)
        self.is_neox = is_neox
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.q_size = self.num_heads * self.head_size
        self.kv_size = self.num_kv_heads * self.head_size
        self.rotary_dim = head_size
        if use_flashinfer:
            self.rotary_op = FLASHINFER_ROTARY_OP
        else:
            self.rotary_op = ROTARY_OP

    def inputs(self) -> list[torch.Tensor]:
        positions = self.empty_int64(5)
        query = self.empty(5, self.q_size)
        key = self.empty(5, self.kv_size)
        cos_sin_cache = self.empty(4096, self.rotary_dim)
        return [positions, query, key, cos_sin_cache]

    def forward_custom(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        result = auto_functionalized(
            self.rotary_op,
            positions=positions,
            query=query,
            key=key,
            head_size=self.head_size,
            cos_sin_cache=cos_sin_cache,
            is_neox=self.is_neox,
        )
        query_out = result[1]
        key_out = result[2] if len(result) > 2 else None
        return query_out, key_out

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        return RotaryEmbedding.forward_static(
            positions,
            query,
            key,
            self.head_size,
            self.rotary_dim,
            cos_sin_cache,
            self.is_neox,
        )
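
A usage sketch (head counts and head size are illustrative; inputs() supplies sample positions, query, key, and cos/sin cache tensors):

matcher = MatcherRotaryEmbedding(
    is_neox=True, head_size=64, num_heads=8, num_kv_heads=8
)
positions, query, key, cos_sin_cache = matcher.inputs()
q_out, k_out = matcher(positions, query, key, cos_sin_cache)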

head_size instance-attribute

head_size = head_size

is_neox instance-attribute

is_neox = is_neox

kv_size instance-attribute

kv_size = num_kv_heads * head_size

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

q_size instance-attribute

q_size = num_heads * head_size

rotary_dim instance-attribute

rotary_dim = head_size

rotary_op instance-attribute

__init__

__init__(
    is_neox: bool,
    head_size: int,
    num_heads: int,
    num_kv_heads: int,
    use_flashinfer: bool = False,
    enabled: bool | None = None,
) -> None
Source code in vllm/compilation/matcher_utils.py
def __init__(
    self,
    is_neox: bool,
    head_size: int,
    num_heads: int,
    num_kv_heads: int,
    use_flashinfer: bool = False,
    enabled: bool | None = None,
) -> None:
    if enabled is None:
        enabled = RotaryEmbedding.enabled()

    super().__init__(enabled)
    self.is_neox = is_neox
    self.head_size = head_size
    self.num_heads = num_heads
    self.num_kv_heads = num_kv_heads
    self.q_size = self.num_heads * self.head_size
    self.kv_size = self.num_kv_heads * self.head_size
    self.rotary_dim = head_size
    if use_flashinfer:
        self.rotary_op = FLASHINFER_ROTARY_OP
    else:
        self.rotary_op = ROTARY_OP

forward_custom

forward_custom(
    positions: Tensor,
    query: Tensor,
    key: Tensor | None,
    cos_sin_cache: Tensor,
) -> tuple[Tensor, Tensor | None]
Source code in vllm/compilation/matcher_utils.py
def forward_custom(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None,
    cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    result = auto_functionalized(
        self.rotary_op,
        positions=positions,
        query=query,
        key=key,
        head_size=self.head_size,
        cos_sin_cache=cos_sin_cache,
        is_neox=self.is_neox,
    )
    query_out = result[1]
    key_out = result[2] if len(result) > 2 else None
    return query_out, key_out

forward_native

forward_native(
    positions: Tensor,
    query: Tensor,
    key: Tensor | None,
    cos_sin_cache: Tensor,
) -> tuple[Tensor, Tensor | None]
Source code in vllm/compilation/matcher_utils.py
def forward_native(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None,
    cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    return RotaryEmbedding.forward_static(
        positions,
        query,
        key,
        self.head_size,
        self.rotary_dim,
        cos_sin_cache,
        self.is_neox,
    )

inputs

inputs() -> list[Tensor]
Source code in vllm/compilation/matcher_utils.py
def inputs(self) -> list[torch.Tensor]:
    positions = self.empty_int64(5)
    query = self.empty(5, self.q_size)
    key = self.empty(5, self.kv_size)
    cos_sin_cache = self.empty(4096, self.rotary_dim)
    return [positions, query, key, cos_sin_cache]

MatcherSiluAndMul

Bases: MatcherCustomOp

Source code in vllm/compilation/matcher_utils.py
class MatcherSiluAndMul(MatcherCustomOp):
    def __init__(self, enabled: bool | None = None):
        if enabled is None:
            enabled = SiluAndMul.enabled()
        super().__init__(enabled)

    def inputs(self) -> list[torch.Tensor]:
        input = self.empty(5, 4)
        return [input]

    def forward_custom(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        result = auto_functionalized(SILU_MUL_OP, result=out, input=x)
        return result[1]

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        return SiluAndMul.forward_native(x)
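
A usage sketch. The last dimension of the input is split in half, so the (5, 4) sample from inputs() produces a (5, 2) output:

matcher = MatcherSiluAndMul()
(x,) = matcher.inputs()
out = matcher(x)  # SILU_MUL_OP when SiluAndMul is enabled, SiluAndMul.forward_native otherwise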

__init__

__init__(enabled: bool | None = None)
Source code in vllm/compilation/matcher_utils.py
def __init__(self, enabled: bool | None = None):
    if enabled is None:
        enabled = SiluAndMul.enabled()
    super().__init__(enabled)

forward_custom

forward_custom(x: Tensor) -> Tensor
Source code in vllm/compilation/matcher_utils.py
def forward_custom(
    self,
    x: torch.Tensor,
) -> torch.Tensor:
    d = x.shape[-1] // 2
    output_shape = x.shape[:-1] + (d,)
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
    result = auto_functionalized(SILU_MUL_OP, result=out, input=x)
    return result[1]

forward_native

forward_native(x: Tensor) -> Tensor
Source code in vllm/compilation/matcher_utils.py
def forward_native(
    self,
    x: torch.Tensor,
) -> torch.Tensor:
    return SiluAndMul.forward_native(x)

inputs

inputs() -> list[Tensor]
Source code in vllm/compilation/matcher_utils.py
def inputs(self) -> list[torch.Tensor]:
    input = self.empty(5, 4)
    return [input]