
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration

Modules:

multi_process_adapter
utils
vllm_v1_adapter

__all__ module-attribute

__all__ = [
    "vllm_v1_adapter",
    "multi_process_adapter",
    "LMCacheMPSchedulerAdapter",
    "LMCacheMPWorkerAdapter",
    "LoadStoreOp",
]

LMCacheMPSchedulerAdapter

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
class LMCacheMPSchedulerAdapter:
    def __init__(
        self,
        server_url: str,
        context: zmq.Context,
        model_name: str,
        world_size: int,
        kv_rank: int,
        vllm_block_size: int,
    ):
        """
        Args:
            server_url: The server URL for the LMCache message queue
            context: The ZMQ context

            model_name: The model name used for LMCache keys
            world_size: The world size used for LMCache keys
            kv_rank: The kv rank used for LMCache keys
            vllm_block_size: The block size used in vLLM
        """
        self.mq_client = MessageQueueClient(server_url, context)

        # Request futures
        self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {}

        self.model_name = model_name
        self.world_size = world_size
        self.worker_id = kv_rank

        # Read chunk size from lmcache
        self.chunk_size = get_lmcache_chunk_size(self.mq_client)
        assert self.chunk_size % vllm_block_size == 0, (
            "LMCache chunk size should be a multiple of vLLM block size"
        )
        self.blocks_in_chunk = self.chunk_size // vllm_block_size

    @_lmcache_nvtx_annotate
    def maybe_submit_lookup_request(self, request_id: str, block_hashes: list[bytes]):
        if request_id in self.lookup_futures:
            # Skip if there is already a lookup request
            return

        s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
        keys = [self._create_key(block_hash) for block_hash in s]
        future = send_lmcache_request(
            self.mq_client,
            RequestType.LOOKUP,
            [keys, True],
        )
        self.lookup_futures[request_id] = future

    @_lmcache_nvtx_annotate
    def check_lookup_result(self, request_id: str) -> int | None:
        assert request_id in self.lookup_futures, (
            f"Lookup request for request_id={request_id} has not been submitted"
        )

        future = self.lookup_futures[request_id]
        if not future.query():
            return None

        result = future.result()
        num_chunks = sum(result)
        return num_chunks * self.chunk_size

    def num_blocks_per_chunk(self) -> int:
        """
        Returns:
            The number of vLLM blocks in an LMCache data chunk
        """
        return self.blocks_in_chunk

    # Helper functions
    def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
        """Convert a block hash to an IPC cache engine key"""
        return IPCCacheEngineKey(
            model_name=self.model_name,
            world_size=self.world_size,
            worker_id=self.worker_id,
            chunk_hash=block_hash,
        )
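
A usage sketch for the scheduler-side flow may help: maybe_submit_lookup_request is idempotent per request_id and only enqueues the lookup, while check_lookup_result returns None until the future resolves. The snippet below is a minimal sketch, not the canonical integration; the server URL, model name, and block hash are placeholders, and it assumes the names re-exported in __all__ above can be imported from this package and that an LMCache message-queue server is reachable.

import zmq

from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
    LMCacheMPSchedulerAdapter,
)

# Placeholder configuration; real values come from the vLLM/LMCache setup.
adapter = LMCacheMPSchedulerAdapter(
    server_url="tcp://localhost:5555",
    context=zmq.Context(),
    model_name="example-model",
    world_size=1,
    kv_rank=0,
    vllm_block_size=16,
)

request_id = "req-0"
block_hashes = [b"\x00" * 32]  # placeholder vLLM block hashes

# Idempotent: a second call for the same request_id is a no-op.
adapter.maybe_submit_lookup_request(request_id, block_hashes)

# Poll until the lookup future resolves; None means "not ready yet".
hit = adapter.check_lookup_result(request_id)
if hit is not None:
    # hit = number of matched chunks * chunk_size (see check_lookup_result).
    print(f"LMCache hit length for {request_id}: {hit}")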

blocks_in_chunk instance-attribute

blocks_in_chunk = chunk_size // vllm_block_size

chunk_size instance-attribute

chunk_size = get_lmcache_chunk_size(mq_client)

lookup_futures instance-attribute

lookup_futures: dict[
    str, MessagingFuture[LookupResult]
] = {}

model_name instance-attribute

model_name = model_name

mq_client instance-attribute

mq_client = MessageQueueClient(server_url, context)

worker_id instance-attribute

worker_id = kv_rank

world_size instance-attribute

world_size = world_size

__init__

__init__(
    server_url: str,
    context: Context,
    model_name: str,
    world_size: int,
    kv_rank: int,
    vllm_block_size: int,
)

Parameters:

Name              Type     Description                                     Default
server_url        str      The server URL for the LMCache message queue    required
context           Context  The ZMQ context                                 required
model_name        str      The model name used for LMCache keys            required
world_size        int      The world size used for LMCache keys            required
kv_rank           int      The kv rank used for LMCache keys               required
vllm_block_size   int      The block size used in vLLM                     required

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def __init__(
    self,
    server_url: str,
    context: zmq.Context,
    model_name: str,
    world_size: int,
    kv_rank: int,
    vllm_block_size: int,
):
    """
    Args:
        server_url: The server URL for the LMCache message queue
        context: The ZMQ context

        model_name: The model name used for LMCache keys
        world_size: The world size used for LMCache keys
        kv_rank: The kv rank used for LMCache keys
        vllm_block_size: The block size used in vLLM
    """
    self.mq_client = MessageQueueClient(server_url, context)

    # Request futures
    self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {}

    self.model_name = model_name
    self.world_size = world_size
    self.worker_id = kv_rank

    # Read chunk size from lmcache
    self.chunk_size = get_lmcache_chunk_size(self.mq_client)
    assert self.chunk_size % vllm_block_size == 0, (
        "LMCache chunk size should be a multiple of vLLM block size"
    )
    self.blocks_in_chunk = self.chunk_size // vllm_block_size

_create_key

_create_key(block_hash: bytes) -> IPCCacheEngineKey

Convert a block hash to an IPC cache engine key

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
    """Convert a block hash to an IPC cache engine key"""
    return IPCCacheEngineKey(
        model_name=self.model_name,
        world_size=self.world_size,
        worker_id=self.worker_id,
        chunk_hash=block_hash,
    )

check_lookup_result

check_lookup_result(request_id: str) -> int | None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def check_lookup_result(self, request_id: str) -> int | None:
    assert request_id in self.lookup_futures, (
        f"Lookup request for request_id={request_id} has not been submitted"
    )

    future = self.lookup_futures[request_id]
    if not future.query():
        return None

    result = future.result()
    num_chunks = sum(result)
    return num_chunks * self.chunk_size

maybe_submit_lookup_request

maybe_submit_lookup_request(
    request_id: str, block_hashes: list[bytes]
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def maybe_submit_lookup_request(self, request_id: str, block_hashes: list[bytes]):
    if request_id in self.lookup_futures:
        # Skip if there is already a lookup request
        return

    s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
    keys = [self._create_key(block_hash) for block_hash in s]
    future = send_lmcache_request(
        self.mq_client,
        RequestType.LOOKUP,
        [keys, True],
    )
    self.lookup_futures[request_id] = future

num_blocks_per_chunk

num_blocks_per_chunk() -> int

Returns:

Type   Description
int    The number of vLLM blocks in an LMCache data chunk

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def num_blocks_per_chunk(self) -> int:
    """
    Returns:
        The number of vLLM blocks in an LMCache data chunk
    """
    return self.blocks_in_chunk
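
As a quick illustration of the relationship enforced in __init__ (the LMCache chunk size must be a multiple of the vLLM block size), with made-up numbers in the same units as the vLLM block size:

# Illustrative numbers only.
chunk_size = 256
vllm_block_size = 16
assert chunk_size % vllm_block_size == 0
blocks_in_chunk = chunk_size // vllm_block_size  # 16, i.e. what num_blocks_per_chunk() returns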

LMCacheMPWorkerAdapter

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
class LMCacheMPWorkerAdapter:
    def __init__(
        self,
        server_url: str,
        context: zmq.Context,
        model_name: str,
        world_size: int,
        kv_rank: int,
        vllm_block_size: int,
    ):
        self.mq_client = MessageQueueClient(server_url, context)

        # Instance id for GPU worker
        self.instance_id = os.getpid()

        # Registered kv caches from vLLM
        self.kv_caches: dict[str, torch.Tensor] = {}

        # Request futures
        # request_id -> (future, other merged requests)
        self.store_futures: dict[
            str, tuple[MessagingFuture[StoreResult], list[str]]
        ] = {}
        self.retrieve_futures: dict[
            str, tuple[MessagingFuture[RetrieveResult], list[str]]
        ] = {}

        self.finished_stores: set[str] = set()
        self.previously_finished: set[str] = set()

        self.model_name = model_name
        self.world_size = world_size
        self.worker_id = kv_rank

        # Read chunk size from lmcache
        chunk_size = get_lmcache_chunk_size(self.mq_client)
        assert chunk_size % vllm_block_size == 0, (
            "LMCache chunk size should be a multiple of vLLM block size"
        )
        self.blocks_in_chunk = chunk_size // vllm_block_size

    def register_kv_caches(self, kv_caches: dict[str, KVCache]):
        # Register kv cache and send the request
        self.kv_caches = kv_caches
        logger.info("Registering kv caches")
        future = send_lmcache_request(
            self.mq_client,
            RequestType.REGISTER_KV_CACHE,
            [self.instance_id, wrap_kv_caches(kv_caches)],
        )
        future.result()

    @_lmcache_nvtx_annotate
    def submit_store_request(
        self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
    ):
        keys = self._block_hashes_to_keys(op.block_hashes)
        future = send_lmcache_request(
            self.mq_client,
            RequestType.STORE,
            [keys, self.instance_id, op.block_ids, event.ipc_handle()],
        ).to_cuda_future()
        self.store_futures[request_id] = (future, [])

    @_lmcache_nvtx_annotate
    def submit_retrieve_request(
        self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
    ):
        keys = self._block_hashes_to_keys(op.block_hashes)
        future = send_lmcache_request(
            self.mq_client,
            RequestType.RETRIEVE,
            [keys, self.instance_id, op.block_ids, event.ipc_handle()],
        ).to_cuda_future()
        self.retrieve_futures[request_id] = (future, [])

    @_lmcache_nvtx_annotate
    def batched_submit_store_requests(
        self,
        request_ids: list[str],
        ops: list[LoadStoreOp],
        event: torch.cuda.Event,
    ):
        keys = []
        block_ids = []
        for op in ops:
            keys.extend(self._block_hashes_to_keys(op.block_hashes))
            block_ids.extend(op.block_ids)
        future = send_lmcache_request(
            self.mq_client,
            RequestType.STORE,
            [keys, self.instance_id, block_ids, event.ipc_handle()],
        ).to_cuda_future()
        self.store_futures[request_ids[0]] = (future, request_ids[1:])

    @_lmcache_nvtx_annotate
    def batched_submit_retrieve_requests(
        self,
        request_ids: list[str],
        ops: list[LoadStoreOp],
        event: torch.cuda.Event,
    ):
        keys = []
        block_ids = []
        for op in ops:
            keys.extend(self._block_hashes_to_keys(op.block_hashes))
            block_ids.extend(op.block_ids)
        future = send_lmcache_request(
            self.mq_client,
            RequestType.RETRIEVE,
            [keys, self.instance_id, block_ids, event.ipc_handle()],
        ).to_cuda_future()
        self.retrieve_futures[request_ids[0]] = (future, request_ids[1:])

    @_lmcache_nvtx_annotate
    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        finished_stores = set()
        finished_retrieves = set()
        for request_id, (future, other_reqs) in self.store_futures.items():
            if not future.query():
                continue

            result = future.result()
            finished_stores.add(request_id)
            finished_stores.update(other_reqs)

            if not result:
                # TODO: add error handling here
                logger.error(
                    "Something went wrong when processing the "
                    "store request for request_id=%s",
                    request_id,
                )

        for request_id, (future, other_reqs) in self.retrieve_futures.items():
            if not future.query():
                continue

            result = future.result()
            finished_retrieves.add(request_id)
            finished_retrieves.update(other_reqs)

            if not all(result):
            # TODO: add error handling here
                logger.error(
                    "Something went wrong when processing the "
                    "retrieve request for request_id=%s, result=%s",
                    request_id,
                    result,
                )
            logger.info("Retrieve request for request_id=%s finished", request_id)

        # Remove the finished requests from the tracking dicts
        for request_id in finished_stores:
            self.store_futures.pop(request_id, None)
        for request_id in finished_retrieves:
            self.retrieve_futures.pop(request_id, None)

        # Update the internal states
        self.finished_stores.update(finished_stores)

        ret_stores = set()
        for req_id in finished_req_ids:
            if req_id in self.finished_stores or req_id in self.store_futures:
                self.previously_finished.add(req_id)
            else:
                ret_stores.add(req_id)

        # Calculate the final finished stores
        ret_stores.update(self._update_and_get_finished_store())

        return ret_stores, finished_retrieves

    def num_blocks_per_chunk(self) -> int:
        """
        Returns:
            The number of vLLM blocks in an LMCache data chunk
        """
        return self.blocks_in_chunk

    def shutdown(self):
        # Unregister kv cache
        logger.info("Unregistering kv caches")
        send_lmcache_request(
            self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id]
        ).result()

        self.mq_client.close()

    # Helper functions
    def _update_and_get_finished_store(
        self,
    ) -> set[str]:
        """Converge the internal states about finished stores
        and returns the 'safe finished store request ids' back
        """
        safe_finished_s = self.finished_stores.intersection(self.previously_finished)
        self.finished_stores.difference_update(self.previously_finished)
        self.previously_finished.difference_update(safe_finished_s)

        return safe_finished_s

    def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
        """Convert a block hash to an IPC cache engine key"""
        return IPCCacheEngineKey(
            model_name=self.model_name,
            world_size=self.world_size,
            worker_id=self.worker_id,
            chunk_hash=block_hash,
        )

    def _block_hashes_to_keys(
        self, block_hashes: list[bytes]
    ) -> list[IPCCacheEngineKey]:
        """Convert block hashes to IPC cache engine keys"""
        s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
        return [self._create_key(block_hash) for block_hash in s]
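
A worker-side usage sketch may help tie the methods together: register the KV caches once, submit store/retrieve operations guarded by a CUDA event, and poll get_finished. It is a sketch only; the server URL, model name, hashes, and block ids are placeholders, the empty kv_caches dict stands in for the tensors vLLM normally supplies, and the event is assumed to need interprocess=True so that the event.ipc_handle() call in the submit methods is available.

import zmq
import torch

from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
    LMCacheMPWorkerAdapter,
    LoadStoreOp,
)

worker = LMCacheMPWorkerAdapter(
    server_url="tcp://localhost:5555",  # placeholder
    context=zmq.Context(),
    model_name="example-model",
    world_size=1,
    kv_rank=0,
    vllm_block_size=16,
)

# In vLLM this dict maps layer names to the allocated KV cache tensors;
# an empty dict is used here purely as a placeholder.
kv_caches: dict[str, torch.Tensor] = {}
worker.register_kv_caches(kv_caches)

# The event is shared with the LMCache process via its IPC handle, so it is
# created as an interprocess event and recorded after the KV writes it guards.
event = torch.cuda.Event(interprocess=True)
event.record()

op = LoadStoreOp(block_hashes=[b"\x00" * 32], block_ids=[7])  # placeholder data
worker.submit_store_request("req-0", op, event)

# Later, during the regular polling step:
done_stores, done_retrieves = worker.get_finished(finished_req_ids=set())

worker.shutdown()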

blocks_in_chunk instance-attribute

blocks_in_chunk = chunk_size // vllm_block_size

finished_stores instance-attribute

finished_stores: set[str] = set()

instance_id instance-attribute

instance_id = getpid()

kv_caches instance-attribute

kv_caches: dict[str, Tensor] = {}

model_name instance-attribute

model_name = model_name

mq_client instance-attribute

mq_client = MessageQueueClient(server_url, context)

previously_finished instance-attribute

previously_finished: set[str] = set()

retrieve_futures instance-attribute

retrieve_futures: dict[
    str, tuple[MessagingFuture[RetrieveResult], list[str]]
] = {}

store_futures instance-attribute

store_futures: dict[
    str, tuple[MessagingFuture[StoreResult], list[str]]
] = {}

worker_id instance-attribute

worker_id = kv_rank

world_size instance-attribute

world_size = world_size

__init__

__init__(
    server_url: str,
    context: Context,
    model_name: str,
    world_size: int,
    kv_rank: int,
    vllm_block_size: int,
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def __init__(
    self,
    server_url: str,
    context: zmq.Context,
    model_name: str,
    world_size: int,
    kv_rank: int,
    vllm_block_size: int,
):
    self.mq_client = MessageQueueClient(server_url, context)

    # Instance id for GPU worker
    self.instance_id = os.getpid()

    # Registered kv caches from vLLM
    self.kv_caches: dict[str, torch.Tensor] = {}

    # Request futures
    # request_id -> (future, other merged requests)
    self.store_futures: dict[
        str, tuple[MessagingFuture[StoreResult], list[str]]
    ] = {}
    self.retrieve_futures: dict[
        str, tuple[MessagingFuture[RetrieveResult], list[str]]
    ] = {}

    self.finished_stores: set[str] = set()
    self.previously_finished: set[str] = set()

    self.model_name = model_name
    self.world_size = world_size
    self.worker_id = kv_rank

    # Read chunk size from lmcache
    chunk_size = get_lmcache_chunk_size(self.mq_client)
    assert chunk_size % vllm_block_size == 0, (
        "LMCache chunk size should be a multiple of vLLM block size"
    )
    self.blocks_in_chunk = chunk_size // vllm_block_size

_block_hashes_to_keys

_block_hashes_to_keys(
    block_hashes: list[bytes],
) -> list[IPCCacheEngineKey]

Convert block hashes to IPC cache engine keys

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def _block_hashes_to_keys(
    self, block_hashes: list[bytes]
) -> list[IPCCacheEngineKey]:
    """Convert block hashes to IPC cache engine keys"""
    s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
    return [self._create_key(block_hash) for block_hash in s]

_create_key

_create_key(block_hash: bytes) -> IPCCacheEngineKey

Convert a block hash to an IPC cache engine key

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
    """Convert a block hash to an IPC cache engine key"""
    return IPCCacheEngineKey(
        model_name=self.model_name,
        world_size=self.world_size,
        worker_id=self.worker_id,
        chunk_hash=block_hash,
    )

_update_and_get_finished_store

_update_and_get_finished_store() -> set[str]

Reconcile the internal state for finished stores and return the request ids whose stores are safely finished.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def _update_and_get_finished_store(
    self,
) -> set[str]:
    """Converge the internal states about finished stores
    and returns the 'safe finished store request ids' back
    """
    safe_finished_s = self.finished_stores.intersection(self.previously_finished)
    self.finished_stores.difference_update(self.previously_finished)
    self.previously_finished.difference_update(safe_finished_s)

    return safe_finished_s
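
The two sets form a small handshake: a store request is only reported back once it has both completed on the LMCache side (finished_stores) and been marked finished by vLLM (previously_finished). The standalone snippet below mirrors the three set operations above with example ids; it is illustrative only and does not touch the adapter.

finished_stores = {"a", "b"}      # stores LMCache has completed
previously_finished = {"b", "c"}  # requests vLLM has already marked finished

safe = finished_stores & previously_finished  # {"b"}: safe to report now
finished_stores -= previously_finished        # {"a"}: store done, vLLM not finished yet
previously_finished -= safe                   # {"c"}: vLLM finished, store still pending

print(safe, finished_stores, previously_finished)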

batched_submit_retrieve_requests

batched_submit_retrieve_requests(
    request_ids: list[str],
    ops: list[LoadStoreOp],
    event: Event,
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def batched_submit_retrieve_requests(
    self,
    request_ids: list[str],
    ops: list[LoadStoreOp],
    event: torch.cuda.Event,
):
    keys = []
    block_ids = []
    for op in ops:
        keys.extend(self._block_hashes_to_keys(op.block_hashes))
        block_ids.extend(op.block_ids)
    future = send_lmcache_request(
        self.mq_client,
        RequestType.RETRIEVE,
        [keys, self.instance_id, block_ids, event.ipc_handle()],
    ).to_cuda_future()
    self.retrieve_futures[request_ids[0]] = (future, request_ids[1:])

batched_submit_store_requests

batched_submit_store_requests(
    request_ids: list[str],
    ops: list[LoadStoreOp],
    event: Event,
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def batched_submit_store_requests(
    self,
    request_ids: list[str],
    ops: list[LoadStoreOp],
    event: torch.cuda.Event,
):
    keys = []
    block_ids = []
    for op in ops:
        keys.extend(self._block_hashes_to_keys(op.block_hashes))
        block_ids.extend(op.block_ids)
    future = send_lmcache_request(
        self.mq_client,
        RequestType.STORE,
        [keys, self.instance_id, block_ids, event.ipc_handle()],
    ).to_cuda_future()
    self.store_futures[request_ids[0]] = (future, request_ids[1:])
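
Both batched variants merge several per-request operations into a single LMCache round trip: the merged future is tracked under request_ids[0] and the remaining ids ride along in the tuple, so get_finished marks all of them finished when the one merged request completes. A small sketch, reusing the worker and placeholder data from the worker-adapter sketch above:

import torch

from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import LoadStoreOp

# Two requests stored in one round trip (placeholder hashes and block ids).
ops = [
    LoadStoreOp(block_hashes=[b"\x01" * 32], block_ids=[0]),
    LoadStoreOp(block_hashes=[b"\x02" * 32], block_ids=[1]),
]
event = torch.cuda.Event(interprocess=True)
event.record()

# `worker` as constructed in the earlier worker-adapter sketch.
worker.batched_submit_store_requests(["req-0", "req-1"], ops, event)
# Internally: store_futures["req-0"] == (merged_future, ["req-1"])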

get_finished

get_finished(
    finished_req_ids: set[str],
) -> tuple[set[str] | None, set[str] | None]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def get_finished(
    self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
    finished_stores = set()
    finished_retrieves = set()
    for request_id, (future, other_reqs) in self.store_futures.items():
        if not future.query():
            continue

        result = future.result()
        finished_stores.add(request_id)
        finished_stores.update(other_reqs)

        if not result:
            # TODO: add error handling here
            logger.error(
                "Something went wrong when processing the "
                "store request for request_id=%s",
                request_id,
            )

    for request_id, (future, other_reqs) in self.retrieve_futures.items():
        if not future.query():
            continue

        result = future.result()
        finished_retrieves.add(request_id)
        finished_retrieves.update(other_reqs)

        if not all(result):
            # TODO: add error handling here
            logger.error(
                "Something went wrong when processing the "
                "retrieve request for request_id=%s, result=%s",
                request_id,
                result,
            )
        logger.info("Retrieve request for request_id=%s finished", request_id)

    # Remove the finished requests from the tracking dicts
    for request_id in finished_stores:
        self.store_futures.pop(request_id, None)
    for request_id in finished_retrieves:
        self.retrieve_futures.pop(request_id, None)

    # Update the internal states
    self.finished_stores.update(finished_stores)

    ret_stores = set()
    for req_id in finished_req_ids:
        if req_id in self.finished_stores or req_id in self.store_futures:
            self.previously_finished.add(req_id)
        else:
            ret_stores.add(req_id)

    # Calculate the final finished stores
    ret_stores.update(self._update_and_get_finished_store())

    return ret_stores, finished_retrieves
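
get_finished is intended to be polled each step with the request ids vLLM has just finished; stores are only returned once both the LMCache store has completed and vLLM has reported the request finished, while retrieves are returned as soon as their futures resolve. A minimal polling sketch, continuing the worker example above (vLLM passes each finished id exactly once, so later polls pass an empty set):

import time

# First poll: report to the adapter that vLLM has finished "req-0".
done_stores, done_retrieves = worker.get_finished(finished_req_ids={"req-0"})

# Keep polling until the store for "req-0" is safe to report.
while "req-0" not in done_stores:
    time.sleep(0.01)
    done_stores, done_retrieves = worker.get_finished(finished_req_ids=set())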

num_blocks_per_chunk

num_blocks_per_chunk() -> int

Returns:

Type   Description
int    The number of vLLM blocks in an LMCache data chunk

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def num_blocks_per_chunk(self) -> int:
    """
    Returns:
        The number of vLLM blocks in an LMCache data chunk
    """
    return self.blocks_in_chunk

register_kv_caches

register_kv_caches(kv_caches: dict[str, KVCache])
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def register_kv_caches(self, kv_caches: dict[str, KVCache]):
    # Register kv cache and send the request
    self.kv_caches = kv_caches
    logger.info("Registering kv caches")
    future = send_lmcache_request(
        self.mq_client,
        RequestType.REGISTER_KV_CACHE,
        [self.instance_id, wrap_kv_caches(kv_caches)],
    )
    future.result()

shutdown

shutdown()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def shutdown(self):
    # Unregister kv cache
    logger.info("Unregistering kv caches")
    send_lmcache_request(
        self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id]
    ).result()

    self.mq_client.close()

submit_retrieve_request

submit_retrieve_request(
    request_id: str, op: LoadStoreOp, event: Event
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def submit_retrieve_request(
    self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
):
    keys = self._block_hashes_to_keys(op.block_hashes)
    future = send_lmcache_request(
        self.mq_client,
        RequestType.RETRIEVE,
        [keys, self.instance_id, op.block_ids, event.ipc_handle()],
    ).to_cuda_future()
    self.retrieve_futures[request_id] = (future, [])

submit_store_request

submit_store_request(
    request_id: str, op: LoadStoreOp, event: Event
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def submit_store_request(
    self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
):
    keys = self._block_hashes_to_keys(op.block_hashes)
    future = send_lmcache_request(
        self.mq_client,
        RequestType.STORE,
        [keys, self.instance_id, op.block_ids, event.ipc_handle()],
    ).to_cuda_future()
    self.store_futures[request_id] = (future, [])

LoadStoreOp dataclass

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@dataclass
class LoadStoreOp:
    block_hashes: list[bytes]
    block_ids: list[int]

    def __len__(self) -> int:
        return len(self.block_hashes)

    def __post_init__(self):
        assert len(self.block_hashes) == len(self.block_ids), (
            "The number of block hashes should be equal to the number of block ids "
            f"But got {len(self.block_hashes)} and {len(self.block_ids)}"
        )

block_hashes instance-attribute

block_hashes: list[bytes]

block_ids instance-attribute

block_ids: list[int]

__init__

__init__(
    block_hashes: list[bytes], block_ids: list[int]
) -> None

__len__

__len__() -> int
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def __len__(self) -> int:
    return len(self.block_hashes)

__post_init__

__post_init__()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def __post_init__(self):
    assert len(self.block_hashes) == len(self.block_ids), (
        "The number of block hashes should be equal to the number of block ids "
        f"But got {len(self.block_hashes)} and {len(self.block_ids)}"
    )
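
LoadStoreOp simply pairs block hashes with the vLLM block ids they occupy, and __post_init__ enforces that both lists have the same length. A small example with placeholder values:

from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import LoadStoreOp

op = LoadStoreOp(
    block_hashes=[b"\x01" * 32, b"\x02" * 32],  # placeholder hashes
    block_ids=[10, 11],                         # the matching vLLM block ids
)
assert len(op) == 2

# Mismatched lengths trip the __post_init__ assertion:
# LoadStoreOp(block_hashes=[b"\x01" * 32], block_ids=[10, 11])  # AssertionError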