Skip to content

vllm.model_executor.layers.rotary_embedding.mrope

MRotaryEmbedding

Bases: RotaryEmbedding

Rotary Embedding with Multimodal Sections.

Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
class MRotaryEmbedding(RotaryEmbedding):
    """Rotary Embedding with Multimodal Sections (M-RoPE).

    Positions are either 1-D ``[num_tokens]`` (text only, plain RoPE) or 2-D
    ``[3, num_tokens]`` holding temporal/height/width (T/H/W) indices. In the
    2-D case each of the ``rotary_dim // 2`` frequency slots takes its cos/sin
    value from exactly one of the T/H/W rows, as laid out by ``mrope_section``
    (contiguous chunks by default, or interleaved when ``mrope_interleaved``
    is set). Passing ``scaling_factor`` switches the inverse frequencies and
    the cos/sin cache to their YaRN variants.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: Optional[list[int]] = None,
        mrope_interleaved: bool = False,
        # YaRN parameters.
        *,
        scaling_factor: Optional[float] = None,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        # The YaRN attributes are assigned before super().__init__() because
        # the overridden _compute_inv_freq / _compute_cos_sin_cache read them,
        # and the base constructor presumably builds the cache (verify in
        # RotaryEmbedding).
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        if self.scaling_factor is not None:
            # Get n-d magnitude scaling corrected for interpolation
            self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
        else:
            self.mscale = 1.0

        # In Qwen2.5-VL, the maximum index value is related to the duration of
        # the input video. We enlarge max_position_embeddings 4x to get a
        # larger cos/sin cache.
        self.cache_max_position_num = max_position_embeddings * 4
        super().__init__(
            head_size,
            rotary_dim,
            self.cache_max_position_num,
            base,
            is_neox_style,
            dtype,
        )

        self.mrope_section = mrope_section
        self.mrope_interleaved = mrope_interleaved
        if self.mrope_section:
            # Every frequency slot belongs to exactly one of the T/H/W axes.
            assert sum(self.mrope_section) == rotary_dim // 2

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Use YaRN inverse frequencies when a scaling factor is configured."""
        if self.scaling_factor is None:
            return super()._compute_inv_freq(base)
        return YaRNScalingRotaryEmbedding._compute_inv_freq(self, base)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Use the YaRN cos/sin cache when a scaling factor is configured."""
        if self.scaling_factor is None:
            return super()._compute_cos_sin_cache()
        return YaRNScalingRotaryEmbedding._compute_cos_sin_cache(self)

    def _mrope_merge_sections(self, table: torch.Tensor) -> torch.Tensor:
        """Collapse a per-axis cos or sin table of shape ``[3, ..., rotary_dim // 2]``
        into ``[..., rotary_dim // 2]``, taking each frequency slot from the
        T/H/W row that owns it according to ``mrope_section``."""
        if self.mrope_interleaved:
            return apply_interleaved_rope(table, self.mrope_section)
        return torch.cat(
            [m[i] for i, m in enumerate(table.split(self.mrope_section, dim=-1))],
            dim=-1,
        )

    def _mrope_apply_rotary(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        num_tokens: int,
    ) -> torch.Tensor:
        """Rotate the first ``rotary_dim`` channels of every head of ``x``
        (``[num_tokens, num_heads * head_size]``) and pass the remaining
        channels through unchanged; the input shape is preserved."""
        x_shape = x.shape
        x = x.view(num_tokens, -1, self.head_size)
        x_rot = x[..., : self.rotary_dim]
        x_pass = x[..., self.rotary_dim :]
        x_rot = apply_rotary_emb_dispatch(x_rot, cos, sin, self.is_neox_style)
        return torch.cat((x_rot, x_pass), dim=-1).reshape(x_shape)

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            positions:
                [num_tokens,] (text only) or
                [3, num_tokens] (T/H/W positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
            offsets: accepted for interface compatibility; unused.
        """
        assert positions.ndim == 1 or positions.ndim == 2
        assert key is not None

        self._match_cos_sin_cache_dtype(query)
        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if positions.ndim == 2:
            assert self.mrope_section
            cos = self._mrope_merge_sections(cos)
            sin = self._mrope_merge_sections(sin)

        query = self._mrope_apply_rotary(query, cos, sin, num_tokens)
        key = self._mrope_apply_rotary(key, cos, sin, num_tokens)
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """CUDA implementation: multimodal positions use the Triton kernel.

        Shapes are as in ``forward_native``. ``offsets`` is unused.
        """
        assert positions.ndim == 1 or positions.ndim == 2
        assert key is not None

        if positions.ndim == 2 and not self.is_neox_style:
            # The Triton kernel only implements the NeoX (half-split) pairing
            # of channels; GPT-J style heads must take the native path, which
            # honours is_neox_style.
            return self.forward_native(positions, query, key, offsets)

        self._match_cos_sin_cache_dtype(query)
        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if positions.ndim == 2:
            assert self.mrope_section
            query_shape = query.shape
            key_shape = key.shape

            q, k = triton_mrope(
                query,
                key,
                cos,
                sin,
                self.mrope_section,
                self.head_size,
                self.rotary_dim,
                self.mrope_interleaved,
            )

            return q.reshape(query_shape), k.reshape(key_shape)

        query = self._mrope_apply_rotary(query, cos, sin, num_tokens)
        key = self._mrope_apply_rotary(key, cos, sin, num_tokens)
        return query, key

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """XPU entry point; delegates to the PyTorch-native implementation."""
        return self.forward_native(positions, query, key, offsets)

    def forward_cpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """CPU entry point; delegates to the PyTorch-native implementation."""
        return self.forward_native(positions, query, key, offsets)

    @classmethod
    def get_input_positions(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        second_per_grid_ts: Optional[list[float]],
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[list[list[int]], int]:
        """Get mrope input positions and delta value.

        List-returning wrapper around ``cls.get_input_positions_tensor``;
        ``None`` grid/timing arguments are replaced by empty lists first.
        """

        image_grid_thw = [] if image_grid_thw is None else image_grid_thw
        video_grid_thw = [] if video_grid_thw is None else video_grid_thw
        second_per_grid_ts = [] if second_per_grid_ts is None else second_per_grid_ts

        llm_positions, mrope_position_delta = cls.get_input_positions_tensor(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            context_len=context_len,
            seq_len=seq_len,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
        )

        return llm_positions.tolist(), mrope_position_delta

    @staticmethod
    def get_next_input_positions(
        mrope_position_delta: int,
        context_len: int,
        seq_len: int,
    ) -> list[list[int]]:
        """Return three identical (but distinct) position lists covering
        ``[context_len, seq_len)`` shifted by ``mrope_position_delta``."""
        return [
            list(
                range(
                    context_len + mrope_position_delta, seq_len + mrope_position_delta
                )
            )
            for _ in range(3)
        ]

    @staticmethod
    def get_next_input_positions_tensor(
        out: np.ndarray,
        out_offset: int,
        mrope_position_delta: int,
        context_len: int,
        num_new_tokens: int,
    ) -> None:
        """Fill ``out[:, out_offset : out_offset + num_new_tokens]`` in place
        with consecutive positions starting at
        ``context_len + mrope_position_delta`` (same values in every row)."""
        values = np.arange(
            mrope_position_delta + context_len,
            mrope_position_delta + context_len + num_new_tokens,
            dtype=out.dtype,
        )
        out[:, out_offset : out_offset + num_new_tokens] = values

attn_factor instance-attribute

attn_factor = attn_factor

beta_fast instance-attribute

beta_fast = beta_fast

beta_slow instance-attribute

beta_slow = beta_slow

cache_max_position_num instance-attribute

cache_max_position_num = max_position_embeddings * 4

extrapolation_factor instance-attribute

extrapolation_factor = extrapolation_factor

mrope_interleaved instance-attribute

mrope_interleaved = mrope_interleaved

mrope_section instance-attribute

mrope_section = mrope_section

mscale instance-attribute

mscale = float(
    yarn_get_mscale(scaling_factor) * attn_factor
)

scaling_factor instance-attribute

scaling_factor = scaling_factor

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: dtype,
    mrope_section: Optional[list[int]] = None,
    mrope_interleaved: bool = False,
    *,
    scaling_factor: Optional[float] = None,
    extrapolation_factor: float = 1,
    attn_factor: float = 1,
    beta_fast: int = 32,
    beta_slow: int = 1,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
    mrope_section: Optional[list[int]] = None,
    mrope_interleaved: bool = False,
    # YaRN parameters.
    *,
    scaling_factor: Optional[float] = None,
    extrapolation_factor: float = 1,
    attn_factor: float = 1,
    beta_fast: int = 32,
    beta_slow: int = 1,
) -> None:
    """Initialise M-RoPE, optionally with YaRN scaling.

    Args:
        max_position_embeddings: model context length; the cos/sin cache is
            sized to 4x this value.
        mrope_section: number of frequency slots owned by the T, H and W
            axes; when given, must sum to ``rotary_dim // 2``.
        mrope_interleaved: interleave T/H/W slots instead of chunking them.
        scaling_factor: enables YaRN when not None; the remaining keyword-only
            arguments are the YaRN hyper-parameters.
    """
    # The YaRN attributes are assigned before super().__init__() because the
    # overridden _compute_inv_freq / _compute_cos_sin_cache read them, and the
    # base constructor presumably builds the cache (verify in RotaryEmbedding).
    self.scaling_factor = scaling_factor
    self.extrapolation_factor = extrapolation_factor
    self.attn_factor = attn_factor
    self.beta_fast = beta_fast
    self.beta_slow = beta_slow
    if self.scaling_factor is not None:
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
    else:
        self.mscale = 1.0

    # In Qwen2.5-VL, the maximum index value is related to the duration of
    # the input video. We enlarge max_position_embeddings 4x to get a larger
    # cos/sin cache.
    self.cache_max_position_num = max_position_embeddings * 4
    super().__init__(
        head_size,
        rotary_dim,
        self.cache_max_position_num,
        base,
        is_neox_style,
        dtype,
    )

    self.mrope_section = mrope_section
    self.mrope_interleaved = mrope_interleaved
    if self.mrope_section:
        # Every frequency slot belongs to exactly one of the T/H/W axes.
        assert sum(self.mrope_section) == rotary_dim // 2

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the cos/sin cache, switching to the YaRN variant when scaled."""
    if self.scaling_factor is not None:
        return YaRNScalingRotaryEmbedding._compute_cos_sin_cache(self)
    return super()._compute_cos_sin_cache()

_compute_inv_freq

_compute_inv_freq(base: float) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def _compute_inv_freq(self, base: float) -> torch.Tensor:
    """Compute inverse frequencies, using the YaRN variant when scaled."""
    if self.scaling_factor is not None:
        return YaRNScalingRotaryEmbedding._compute_inv_freq(self, base)
    return super()._compute_inv_freq(base)

forward_cpu

forward_cpu(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def forward_cpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """CPU entry point: M-RoPE on CPU reuses the PyTorch-native path."""
    native_impl = self.forward_native
    return native_impl(positions, query, key, offsets)

forward_cuda

forward_cuda(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """CUDA implementation of M-RoPE.

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
        offsets: accepted for interface compatibility; unused.

    Returns:
        (query, key) in their input shapes with the first ``rotary_dim``
        channels of every head rotated.
    """
    assert positions.ndim == 1 or positions.ndim == 2
    assert key is not None

    if positions.ndim == 2 and not self.is_neox_style:
        # triton_mrope only implements the NeoX (half-split) channel pairing;
        # GPT-J style heads must take the native path, which honours
        # is_neox_style.
        return self.forward_native(positions, query, key, offsets)

    self._match_cos_sin_cache_dtype(query)
    num_tokens = positions.shape[-1]
    cos_sin = self.cos_sin_cache[positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    query_shape = query.shape
    key_shape = key.shape
    if positions.ndim == 2:
        assert self.mrope_section

        q, k = triton_mrope(
            query,
            key,
            cos,
            sin,
            self.mrope_section,
            self.head_size,
            self.rotary_dim,
            self.mrope_interleaved,
        )

        return q.reshape(query_shape), k.reshape(key_shape)

    # Text-only positions: plain RoPE on the leading rotary_dim channels.
    query = query.view(num_tokens, -1, self.head_size)
    query_rot = query[..., : self.rotary_dim]
    query_pass = query[..., self.rotary_dim :]
    query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
    query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

    key = key.view(num_tokens, -1, self.head_size)
    key_rot = key[..., : self.rotary_dim]
    key_pass = key[..., self.rotary_dim :]
    key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
    key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
    return query, key

forward_native

forward_native(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]

PyTorch-native implementation equivalent to forward().

Parameters:

Name Type Description Default
positions Tensor

[num_tokens,] (text only) or [3, num_tokens] (T/H/W positions with multimodal inputs)

required
query Tensor

[num_tokens, num_heads * head_size]

required
key Optional[Tensor]

[num_tokens, num_kv_heads * head_size]

None
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def forward_native(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Reference (pure PyTorch) M-RoPE, equivalent to forward().

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
        offsets: accepted for interface compatibility; unused.

    Returns:
        (query, key) in their input shapes with the first ``rotary_dim``
        channels of every head rotated.
    """
    assert positions.ndim == 1 or positions.ndim == 2
    assert key is not None

    self._match_cos_sin_cache_dtype(query)
    num_tokens = positions.shape[-1]
    cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)

    if positions.ndim == 2:
        assert self.mrope_section
        section = self.mrope_section

        def select_axes(table: torch.Tensor) -> torch.Tensor:
            # For every frequency slot keep the T/H/W row that owns it.
            if self.mrope_interleaved:
                return apply_interleaved_rope(table, section)
            parts = table.split(section, dim=-1)
            return torch.cat([part[axis] for axis, part in enumerate(parts)], dim=-1)

        cos = select_axes(cos)
        sin = select_axes(sin)

    def rotate(x: torch.Tensor) -> torch.Tensor:
        # Rotate the leading rotary_dim channels of each head; keep the rest.
        original_shape = x.shape
        heads = x.view(num_tokens, -1, self.head_size)
        rotated = apply_rotary_emb_dispatch(
            heads[..., : self.rotary_dim], cos, sin, self.is_neox_style
        )
        untouched = heads[..., self.rotary_dim :]
        return torch.cat((rotated, untouched), dim=-1).reshape(original_shape)

    return rotate(query), rotate(key)

forward_xpu

forward_xpu(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def forward_xpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """XPU entry point: M-RoPE on XPU reuses the PyTorch-native path."""
    native_impl = self.forward_native
    return native_impl(positions, query, key, offsets)

get_input_positions classmethod

get_input_positions(
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Optional[
        Union[list[list[int]], Tensor]
    ],
    video_grid_thw: Optional[
        Union[list[list[int]], Tensor]
    ],
    second_per_grid_ts: Optional[list[float]],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[list[list[int]], int]

Get mrope input positions and delta value.

Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@classmethod
def get_input_positions(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
    video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
    second_per_grid_ts: Optional[list[float]],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[list[list[int]], int]:
    """Get mrope input positions (as nested lists) and the delta value.

    Wraps ``cls.get_input_positions_tensor``. ``None`` grid/timing arguments
    are replaced by empty lists before the call.
    """

    def _or_empty(value):
        # An explicit None check: tensors have no unambiguous truth value.
        return [] if value is None else value

    positions, delta = cls.get_input_positions_tensor(
        input_tokens=input_tokens,
        hf_config=hf_config,
        image_grid_thw=_or_empty(image_grid_thw),
        video_grid_thw=_or_empty(video_grid_thw),
        second_per_grid_ts=_or_empty(second_per_grid_ts),
        context_len=context_len,
        seq_len=seq_len,
        audio_feature_lengths=audio_feature_lengths,
        use_audio_in_video=use_audio_in_video,
    )
    return positions.tolist(), delta

get_next_input_positions staticmethod

get_next_input_positions(
    mrope_position_delta: int,
    context_len: int,
    seq_len: int,
) -> list[list[int]]
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@staticmethod
def get_next_input_positions(
    mrope_position_delta: int,
    context_len: int,
    seq_len: int,
) -> list[list[int]]:
    """Positions for text-only steps: three identical T/H/W rows.

    Each row is an independent list covering ``[context_len, seq_len)``
    shifted by ``mrope_position_delta``.
    """
    start = context_len + mrope_position_delta
    stop = seq_len + mrope_position_delta
    return [list(range(start, stop)) for _axis in range(3)]

get_next_input_positions_tensor staticmethod

get_next_input_positions_tensor(
    out: ndarray,
    out_offset: int,
    mrope_position_delta: int,
    context_len: int,
    num_new_tokens: int,
)
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@staticmethod
def get_next_input_positions_tensor(
    out: np.ndarray,
    out_offset: int,
    mrope_position_delta: int,
    context_len: int,
    num_new_tokens: int,
):
    """Write text-only positions into ``out`` in place.

    Columns ``[out_offset, out_offset + num_new_tokens)`` of every row are set
    to consecutive positions starting at ``context_len + mrope_position_delta``,
    cast to ``out.dtype``.
    """
    first = mrope_position_delta + context_len
    columns = slice(out_offset, out_offset + num_new_tokens)
    out[:, columns] = np.arange(first, first + num_new_tokens, dtype=out.dtype)

_triton_mrope_forward

_triton_mrope_forward(
    q_ptr,
    k_ptr,
    cos,
    sin,
    num_tokens,
    n_qh: constexpr,
    n_kh: constexpr,
    hd: constexpr,
    rd: constexpr,
    pad_n_qh: constexpr,
    pad_n_kh: constexpr,
    pad_hd: constexpr,
    mrope_section_t: constexpr,
    mrope_section_h: constexpr,
    mrope_section_w: constexpr,
    is_interleaved: constexpr,
)
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@triton.jit
def _triton_mrope_forward(
    q_ptr,
    k_ptr,
    cos,
    sin,
    num_tokens,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    rd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
    mrope_section_w: tl.constexpr,
    is_interleaved: tl.constexpr,
):
    # Adapted from
    # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
    # This version supports flatten input tensors from vllm
    # and supports cos and sin cache with shape (3, num_tokens, rotary_dim // 2)
    # instead of (3, bsz, seq_len, head_dim), also supports interleaved rotary.
    # One program instance handles one token; q and k are rotated in place
    # using the NeoX (half-split) pairing: channel i pairs with i + rd // 2.
    pid = tl.program_id(0)
    # locate start address (assumes dense [num_tokens, n_heads * hd] rows)
    q_ptr = q_ptr + pid * (n_qh * hd)
    k_ptr = k_ptr + pid * (n_kh * hd)

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################
    # Note: cos and sin have shape (3, num_tokens, rotary_dim // 2)

    # Stride of one token row in the half-rotary-dim cos/sin tables
    half_rd = rd // 2
    t_cos = cos + pid * half_rd
    h_cos = t_cos + num_tokens * half_rd
    w_cos = h_cos + num_tokens * half_rd
    t_sin = sin + pid * half_rd
    h_sin = t_sin + num_tokens * half_rd
    w_sin = h_sin + num_tokens * half_rd

    # Offsets over the (padded) half rotary dim
    cos_offsets = tl.arange(0, pad_hd // 2)
    if is_interleaved:
        # Slots cycle T, H, W; H/W own the first section_h / section_w
        # slots of their residue class (matches apply_interleaved_rope).
        h_mask = ((cos_offsets % 3) == 1) & (cos_offsets < 3 * mrope_section_h)
        w_mask = ((cos_offsets % 3) == 2) & (cos_offsets < 3 * mrope_section_w)
        # Bound T to the real row so padded lanes never read past it.
        t_mask = ~(h_mask | w_mask) & (cos_offsets < half_rd)
    else:
        t_end = mrope_section_t
        h_end = t_end + mrope_section_h
        t_mask = cos_offsets < mrope_section_t
        h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
        w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)

    # Masks are disjoint, so summing the masked loads selects one axis per slot
    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)

    cos_row = t_cos_row + h_cos_row + w_cos_row
    sin_row = t_sin_row + h_sin_row + w_sin_row

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head
    first_half_q_offsets = (
        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_half_k_offsets = (
        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
    )
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
    )

    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
        sin_row.dtype
    )

    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (rd // 2)
    second_half_k_offsets = first_half_k_offsets + (rd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask

    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
        sin_row.dtype
    )

    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
    # Since cos and sin are half-size,
    # we use the same cos_row and sin_row for both halves
    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)

apply_interleaved_rope

apply_interleaved_rope(
    x: Tensor, mrope_section: list[int]
) -> Tensor

Apply interleaved MRoPE to 3D rotary embeddings. Reorganizes frequency layout from chunked [TTT...HHH...WWW] to interleaved [THWTHWTHW...TT], preserving frequency continuity.

Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
    """Merge per-axis (T/H/W) rotary tables using the interleaved layout.

    ``x`` stacks the T, H and W tables along dim 0. The result starts as a
    copy of the T table. The H table then supplies slots 1, 4, 7, ... (its
    first ``mrope_section[1]`` slots of that stride) and W supplies slots
    2, 5, 8, ... (first ``mrope_section[2]``). The layout therefore runs
    [THWTHW...TT] instead of chunked [TTT...HHH...WWW]. ``x`` is not modified.
    """
    merged = x[0].clone()
    for axis in (1, 2):
        stop = mrope_section[axis] * 3
        merged[..., axis:stop:3] = x[axis, ..., axis:stop:3]
    return merged

triton_mrope

triton_mrope(
    q: Tensor,
    k: Tensor,
    cos: Tensor,
    sin: Tensor,
    mrope_section: list[int],
    head_size: int,
    rotary_dim: int,
    mrope_interleaved: bool,
) -> tuple[Tensor, Tensor]

Qwen2VL mrope kernel.

Parameters:

Name Type Description Default
q Tensor

[num_tokens, num_heads * head_size]

required
k Tensor

[num_tokens, num_kv_heads * head_size]

required
cos Tensor

[3, num_tokens, rotary_dim // 2] (T/H/W positions with multimodal inputs)

required
sin Tensor

[3, num_tokens, rotary_dim // 2] (T/H/W positions with multimodal inputs)

required
mrope_section list[int]

[t, h, w]

required
head_size int

Size of each attention head.

required
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def triton_mrope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    mrope_section: list[int],
    head_size: int,
    rotary_dim: int,
    mrope_interleaved: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Qwen2VL mrope kernel launcher.

    Rotates ``q`` and ``k`` in place with the NeoX (half-split) channel
    pairing and returns them. If an input is not contiguous, a contiguous copy
    is rotated and returned instead.

    Args:
        q: [num_tokens, num_heads * head_size]
        k: [num_tokens, num_kv_heads * head_size]
        cos: [3, num_tokens, rotary_dim // 2] per-axis (T/H/W) cos values
        sin: [3, num_tokens, rotary_dim // 2] per-axis (T/H/W) sin values
        mrope_section: [t, h, w] frequency-slot counts per axis
        head_size: size of each attention head
        rotary_dim: number of leading channels per head that are rotated
        mrope_interleaved: interleave T/H/W slots instead of chunking them
    """
    num_tokens, q_width = q.shape
    num_q_heads = q_width // head_size
    num_kv_heads = k.shape[1] // head_size

    # The kernel walks raw pointers assuming dense rows;
    # .contiguous() is a no-op when a tensor already is.
    q, k = q.contiguous(), k.contiguous()
    cos, sin = cos.contiguous(), sin.contiguous()

    grid = (num_tokens,)
    _triton_mrope_forward[grid](
        q,
        k,
        cos,
        sin,
        num_tokens,
        num_q_heads,
        num_kv_heads,
        head_size,
        rotary_dim,
        triton.next_power_of_2(num_q_heads),
        triton.next_power_of_2(num_kv_heads),
        triton.next_power_of_2(head_size),
        mrope_section[0],
        mrope_section[1],
        mrope_section[2],
        mrope_interleaved,
    )
    return q, k