Skip to content

vllm.model_executor.models.qwen2_5_omni_thinker

Inference-only Qwen2.5-Omni model (thinker part).

logger module-attribute

# Module-level logger, created with `init_logger` and named after this module.
logger = init_logger(__name__)

Qwen2_5OmniAudioFeatureInputs

Bases: TensorSchema

Dimensions
  • na: Number of audios
  • nmb: Number of mel bins
  • msl: Maximum sequence length
  • tsl: Total sequence length
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
    """
    Audio feature inputs for the Qwen2.5-Omni thinker.

    Dimensions:
        - na: Number of audios
        - nmb: Number of mel bins
        - msl: Maximum sequence length
        - tsl: Total sequence length
    """

    type: Literal["audio_features"]
    input_features: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("nmb", "tsl"),
    ]

    # Per-audio feature lengths, consumed by `_process_audio_input` to size
    # and split the audio-tower outputs. The consumer accepts either a 1-D
    # tensor or a 2-D tensor with a singleton axis, so no fixed TensorShape
    # is declared here.
    audio_feature_lengths: Optional[torch.Tensor]

    # Optional: the parser only reshapes this mask when it is present, and
    # audio encoding does not read it.
    feature_attention_mask: Annotated[
        Optional[torch.Tensor],
        TensorShape("na", "msl"),
    ]

feature_attention_mask instance-attribute

feature_attention_mask: Annotated[
    Tensor, TensorShape(na, msl)
]

input_features instance-attribute

input_features: Annotated[
    Union[Tensor, list[Tensor]], TensorShape(nmb, tsl)
]

type instance-attribute

type: Literal['audio_features']

Qwen2_5OmniConditionalGenerationMixin

Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
class Qwen2_5OmniConditionalGenerationMixin:
    """Multimodal input parsing and encoding helpers for the Qwen2.5-Omni thinker.

    The host class must provide ``self.audio_tower`` and ``self.visual``.
    Both are used here but not defined by this mixin.
    """

    def _validate_and_reshape_mm_tensor(
        self, mm_input: object, name: str, dim: int = 0
    ) -> torch.Tensor:
        """Validate a batched multimodal input and flatten it into one tensor.

        Args:
            mm_input: A tensor or a list of tensors.
            name: Human-readable input name used in the error message.
            dim: Concatenation dimension. With ``dim=0`` a tensor input has
                its leading two dimensions merged.

        Returns:
            A single tensor.

        Raises:
            ValueError: If ``mm_input`` is neither a tensor nor a list.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if dim == 0:
                return mm_input.reshape(-1, *mm_input.shape[2:])
            # Unbind along dim 0, then join the pieces along `dim`.
            return torch.concat(list(mm_input), dim=dim)
        return torch.concat(mm_input, dim=dim)

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
        """Pop audio kwargs and wrap them in ``Qwen2_5OmniAudioFeatureInputs``.

        Returns ``None`` when no ``input_audio_features`` were supplied.
        """
        input_audio_features = kwargs.pop("input_audio_features", None)
        audio_feature_lengths = kwargs.pop("audio_feature_lengths", None)
        feature_attention_mask = kwargs.pop("feature_attention_mask", None)
        if input_audio_features is None:
            return None
        # Items are joined along dim=1, the "tsl" axis of the (nmb, tsl) schema.
        # The helper always returns a tensor (or raises), so no extra type
        # check is needed afterwards.
        input_audio_features = self._validate_and_reshape_mm_tensor(
            input_audio_features, "input_audio_features", dim=1
        )
        if feature_attention_mask is not None:
            feature_attention_mask = self._validate_and_reshape_mm_tensor(
                feature_attention_mask, "feature_attention_mask"
            )
        return Qwen2_5OmniAudioFeatureInputs(
            type="audio_features",
            input_features=input_audio_features,
            audio_feature_lengths=audio_feature_lengths,
            feature_attention_mask=feature_attention_mask,
        )

    def _parse_and_validate_image_input(
        self,
        **kwargs: object,
    ) -> Optional[Qwen2_5_VLImageInputs]:
        """Pop image kwargs and build pixel or embedding image inputs.

        Pixel values take precedence over embeddings. Returns ``None`` when
        neither is supplied.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values"
            )
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw"
            )
            return Qwen2_5_VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        # Only embeddings were provided at this point.
        image_embeds = self._validate_and_reshape_mm_tensor(
            image_embeds, "image embeds"
        )
        image_grid_thw = self._validate_and_reshape_mm_tensor(
            image_grid_thw, "image grid_thw"
        )
        return Qwen2_5_VLImageEmbeddingInputs(
            type="image_embeds",
            image_embeds=image_embeds,
            image_grid_thw=image_grid_thw,
        )

    def _parse_and_validate_video_input(
        self,
        **kwargs: object,
    ) -> Optional[Qwen2_5_VLVideoInputs]:
        """Pop video kwargs and build pixel or embedding video inputs.

        Pixel values take precedence over embeddings. Returns ``None`` when
        neither is supplied.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values"
            )
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw"
            )
            return Qwen2_5_VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        # Only embeddings were provided at this point.
        video_embeds = self._validate_and_reshape_mm_tensor(
            video_embeds, "video embeds"
        )
        video_grid_thw = self._validate_and_reshape_mm_tensor(
            video_grid_thw, "video grid_thw"
        )
        return Qwen2_5_VLVideoEmbeddingInputs(
            type="video_embeds",
            video_embeds=video_embeds,
            video_grid_thw=video_grid_thw,
        )

    def _process_audio_input(
        self,
        audio_input: Qwen2_5OmniAudioFeatureInputs,
        audio_hashes: Optional[list[str]] = None,
        cached_audio_features: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, ...]:
        """Encode audio features with ``self.audio_tower`` and split them per item.

        ``audio_hashes`` and ``cached_audio_features`` are accepted for
        interface compatibility and are not used here.

        Raises:
            ValueError: If ``audio_feature_lengths`` is missing.
        """
        input_features = audio_input["input_features"]
        audio_feature_lengths = audio_input["audio_feature_lengths"]
        if audio_feature_lengths is None:
            raise ValueError(
                "audio_feature_lengths is required to encode audio input"
            )
        if input_features.ndim == 3:
            assert input_features.shape[0] == 1
            input_features = input_features.squeeze(0)
        if audio_feature_lengths.ndim == 2:
            # Accept (1, na) or (na, 1); flatten the singleton axis away.
            assert (
                audio_feature_lengths.shape[0] == 1
                or audio_feature_lengths.shape[1] == 1
            )
            audio_feature_lengths = audio_feature_lengths.reshape(-1)

        audio_feat_lengths, audio_output_lengths = (
            self.audio_tower._get_feat_extract_output_lengths(audio_feature_lengths)
        )

        audio_outputs = self.audio_tower(
            input_features.to(self.audio_tower.dtype),
            feature_lens=audio_feature_lengths,
            aftercnn_lens=audio_feat_lengths,
        )
        return audio_outputs.last_hidden_state.split(audio_output_lengths.tolist())

    def _process_image_input(
        self, image_input: Qwen2_5_VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image.

        Pixel inputs are encoded with ``self.visual``. Precomputed embeddings
        are cast to the vision dtype. In both cases the flat result is split
        per image using ``image_grid_thw``, so each item yields
        ``t*h*w // merge_size**2`` rows.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
        return image_embeds.split(sizes.tolist())

    def _process_video_input(
        self,
        video_input: Qwen2_5_VLVideoInputs,
        video_hashes: Optional[list[str]] = None,
        cached_video_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per video.

        Pixel inputs are encoded with ``self.visual``. Precomputed embeddings
        are cast to the vision dtype. Both are split per video using
        ``video_grid_thw``. ``video_hashes`` and ``cached_video_embeds`` are
        accepted for interface compatibility and are not used here.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"].type(
                self.visual.dtype
            )
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
        return video_embeds.split(sizes.tolist())

_parse_and_validate_audio_input

_parse_and_validate_audio_input(
    **kwargs: object,
) -> Optional[Qwen2_5OmniAudioFeatureInputs]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _parse_and_validate_audio_input(
    self, **kwargs: object
) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
    """Extract audio kwargs and package them as ``Qwen2_5OmniAudioFeatureInputs``.

    Returns ``None`` if ``input_audio_features`` is absent.
    """
    features = kwargs.pop("input_audio_features", None)
    lengths = kwargs.pop("audio_feature_lengths", None)
    mask = kwargs.pop("feature_attention_mask", None)

    if features is None:
        return None

    # Items are concatenated along dim=1.
    features = self._validate_and_reshape_mm_tensor(
        features, "input_audio_features", dim=1
    )
    if mask is not None:
        mask = self._validate_and_reshape_mm_tensor(mask, "feature_attention_mask")

    if not isinstance(features, (torch.Tensor, list)):
        raise ValueError(
            "Incorrect type of audio input features. "
            f"Got type: {type(features)}"
        )

    return Qwen2_5OmniAudioFeatureInputs(
        type="audio_features",
        input_features=features,
        audio_feature_lengths=lengths,
        feature_attention_mask=mask,
    )

_parse_and_validate_image_input

_parse_and_validate_image_input(
    **kwargs: dict[str, Any],
) -> Optional[Qwen2_5_VLImageInputs]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _parse_and_validate_image_input(
    self,
    **kwargs: dict[str, Any],
) -> Optional[Qwen2_5_VLImageInputs]:
    """Extract image kwargs and build pixel-based or embedding-based inputs.

    Pixel values win when both are present. Returns ``None`` when neither
    ``pixel_values`` nor ``image_embeds`` is given.
    """
    pixels = kwargs.pop("pixel_values", None)
    embeds = kwargs.pop("image_embeds", None)
    grid = kwargs.pop("image_grid_thw", None)

    if pixels is None and embeds is None:
        return None

    if pixels is None:
        # Embedding path: only `image_embeds` was supplied.
        embeds = self._validate_and_reshape_mm_tensor(embeds, "image embeds")
        grid = self._validate_and_reshape_mm_tensor(grid, "image grid_thw")
        if not isinstance(embeds, torch.Tensor):
            raise ValueError(
                "Incorrect type of image embeddings. "
                f"Got type: {type(embeds)}"
            )
        return Qwen2_5_VLImageEmbeddingInputs(
            type="image_embeds",
            image_embeds=embeds,
            image_grid_thw=grid,
        )

    pixels = self._validate_and_reshape_mm_tensor(pixels, "image pixel values")
    grid = self._validate_and_reshape_mm_tensor(grid, "image grid_thw")
    if not isinstance(pixels, (torch.Tensor, list)):
        raise ValueError(
            "Incorrect type of image pixel values. "
            f"Got type: {type(pixels)}"
        )
    return Qwen2_5_VLImagePixelInputs(
        type="pixel_values",
        pixel_values=pixels,
        image_grid_thw=grid,
    )

_parse_and_validate_video_input

_parse_and_validate_video_input(
    **kwargs: dict[str, Any],
) -> Optional[Qwen2_5_VLVideoInputs]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _parse_and_validate_video_input(
    self,
    **kwargs: dict[str, Any],
) -> Optional[Qwen2_5_VLVideoInputs]:
    """Extract video kwargs and build pixel-based or embedding-based inputs.

    Pixel values win when both are present. Returns ``None`` when neither
    ``pixel_values_videos`` nor ``video_embeds`` is given.
    """
    pixels = kwargs.pop("pixel_values_videos", None)
    embeds = kwargs.pop("video_embeds", None)
    grid = kwargs.pop("video_grid_thw", None)

    if pixels is None and embeds is None:
        return None

    if pixels is None:
        # Embedding path: only `video_embeds` was supplied.
        embeds = self._validate_and_reshape_mm_tensor(embeds, "video embeds")
        grid = self._validate_and_reshape_mm_tensor(grid, "video grid_thw")
        if not isinstance(embeds, torch.Tensor):
            raise ValueError(
                "Incorrect type of video embeddings. "
                f"Got type: {type(embeds)}"
            )
        return Qwen2_5_VLVideoEmbeddingInputs(
            type="video_embeds",
            video_embeds=embeds,
            video_grid_thw=grid,
        )

    pixels = self._validate_and_reshape_mm_tensor(pixels, "video pixel values")
    grid = self._validate_and_reshape_mm_tensor(grid, "video grid_thw")
    return Qwen2_5_VLVideoPixelInputs(
        type="pixel_values_videos",
        pixel_values_videos=pixels,
        video_grid_thw=grid,
    )

_process_audio_input

_process_audio_input(
    audio_input: Qwen2_5OmniAudioFeatureInputs,
    audio_hashes: list[str] = None,
    cached_audio_features: Tensor = None,
) -> Tensor
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _process_audio_input(
    self,
    audio_input: Qwen2_5OmniAudioFeatureInputs,
    audio_hashes: Optional[list[str]] = None,
    cached_audio_features: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, ...]:
    """Encode audio features with ``self.audio_tower`` and split them per item.

    ``audio_hashes`` and ``cached_audio_features`` are accepted for
    interface compatibility and are not used here.

    Raises:
        ValueError: If ``audio_feature_lengths`` is missing.
    """
    input_features = audio_input["input_features"]
    audio_feature_lengths = audio_input["audio_feature_lengths"]
    if audio_feature_lengths is None:
        raise ValueError("audio_feature_lengths is required to encode audio input")
    if input_features.ndim == 3:
        assert input_features.shape[0] == 1
        input_features = input_features.squeeze(0)
    if audio_feature_lengths.ndim == 2:
        # Accept (1, na) or (na, 1); flatten the singleton axis away.
        assert (
            audio_feature_lengths.shape[0] == 1
            or audio_feature_lengths.shape[1] == 1
        )
        audio_feature_lengths = audio_feature_lengths.reshape(-1)

    audio_feat_lengths, audio_output_lengths = (
        self.audio_tower._get_feat_extract_output_lengths(audio_feature_lengths)
    )

    audio_outputs = self.audio_tower(
        input_features.to(self.audio_tower.dtype),
        feature_lens=audio_feature_lengths,
        aftercnn_lens=audio_feat_lengths,
    )
    return audio_outputs.last_hidden_state.split(audio_output_lengths.tolist())

_process_image_input

_process_image_input(
    image_input: Qwen2_5_VLImageInputs,
) -> tuple[Tensor, ...]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _process_image_input(
    self, image_input: Qwen2_5_VLImageInputs
) -> tuple[torch.Tensor, ...]:
    """Return one embedding tensor per image.

    Pixel inputs are encoded with ``self.visual``. Precomputed embeddings are
    cast to the vision dtype. In both cases the flat result is split per image
    using ``image_grid_thw`` (``t*h*w // merge_size**2`` rows per item). The
    embeddings path previously returned one unsplit tensor, which did not match
    the declared tuple return type.
    """
    grid_thw = image_input["image_grid_thw"]
    assert grid_thw.ndim == 2

    if image_input["type"] == "image_embeds":
        image_embeds = image_input["image_embeds"].type(self.visual.dtype)
    else:
        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

    # Split concatenated embeddings for each image item.
    merge_size = self.visual.spatial_merge_size
    sizes = grid_thw.prod(-1) // merge_size // merge_size
    return image_embeds.split(sizes.tolist())

_process_video_input

_process_video_input(
    video_input: Qwen2_5_VLVideoInputs,
    video_hashes: list[str] = None,
    cached_video_embeds: Tensor = None,
) -> Tensor
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _process_video_input(
    self,
    video_input: Qwen2_5_VLVideoInputs,
    video_hashes: Optional[list[str]] = None,
    cached_video_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, ...]:
    """Return one embedding tensor per video.

    Pixel inputs are encoded with ``self.visual``. Precomputed embeddings are
    cast to the vision dtype. Both are split per video using
    ``video_grid_thw``. ``video_hashes`` and ``cached_video_embeds`` are
    accepted for interface compatibility and are not used here.
    """
    grid_thw = video_input["video_grid_thw"]
    assert grid_thw.ndim == 2

    if video_input["type"] == "video_embeds":
        video_embeds = video_input["video_embeds"].type(self.visual.dtype)
    else:
        pixel_values_videos = video_input["pixel_values_videos"].type(
            self.visual.dtype
        )
        video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

    # Split concatenated embeddings for each video item.
    merge_size = self.visual.spatial_merge_size
    sizes = grid_thw.prod(-1) // merge_size // merge_size
    return video_embeds.split(sizes.tolist())

_validate_and_reshape_mm_tensor

_validate_and_reshape_mm_tensor(
    mm_input: object, name: str, dim: int = 0
) -> Tensor
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _validate_and_reshape_mm_tensor(
    self, mm_input: object, name: str, dim: int = 0
) -> torch.Tensor:
    """Flatten a tensor-or-list multimodal input into a single tensor.

    Lists are concatenated along ``dim``. A tensor with ``dim == 0`` has its
    first two dimensions merged. For any other ``dim``, the tensor is unbound
    along dim 0 and the pieces are concatenated along ``dim``.

    Raises:
        ValueError: If ``mm_input`` is neither a tensor nor a list.
    """
    if isinstance(mm_input, list):
        return torch.concat(mm_input, dim=dim)
    if not isinstance(mm_input, torch.Tensor):
        raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
    if dim != 0:
        return torch.concat(list(mm_input), dim=dim)
    return mm_input.reshape(-1, *mm_input.shape[2:])

Qwen2_5OmniThinkerDummyInputsBuilder

Bases: BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]

Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
class Qwen2_5OmniThinkerDummyInputsBuilder(
    BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]
):
    """Builds dummy prompt text and dummy audio/image/video data for the thinker."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one placeholder token per requested item.

        Audio placeholders come first, then image, then video.
        """
        processor = self.info.get_hf_processor()
        token_by_modality = (
            ("audio", processor.audio_token),
            ("image", processor.image_token),
            ("video", processor.video_token),
        )
        return "".join(
            token * mm_counts.get(modality, 0)
            for modality, token in token_by_modality
        )

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Build dummy audio/image/video data sized by the processing info's
        "most features" helpers, applying per-modality overrides from
        ``mm_options`` when given."""
        counts = {m: mm_counts.get(m, 0) for m in ("audio", "image", "video")}

        extractor = self.info.get_feature_extractor()
        # Dummy audio length: min(chunk_length, 30) times the sampling rate
        # (presumably seconds -> samples; confirm against the feature extractor).
        audio_length = min(extractor.chunk_length, 30) * extractor.sampling_rate
        width, height = self.info.get_image_size_with_most_features()
        num_frames = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        overrides = {
            m: (mm_options.get(m) if mm_options else None)
            for m in ("image", "video", "audio")
        }

        return {
            "audio": self._get_dummy_audios(
                length=audio_length,
                num_audios=counts["audio"],
                overrides=overrides["audio"],
            ),
            "image": self._get_dummy_images(
                width=width,
                height=height,
                num_images=counts["image"],
                overrides=overrides["image"],
            ),
            "video": self._get_dummy_videos(
                width=width,
                height=height,
                num_frames=num_frames,
                num_videos=counts["video"],
                overrides=overrides["video"],
            ),
        }

get_dummy_mm_data

get_dummy_mm_data(
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_options: Optional[
        Mapping[str, BaseDummyOptions]
    ] = None,
) -> MultiModalDataDict
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def get_dummy_mm_data(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
    """Build dummy audio/image/video data sized by the processing info's
    "most features" helpers, applying per-modality overrides from
    ``mm_options`` when given."""
    counts = {m: mm_counts.get(m, 0) for m in ("audio", "image", "video")}

    extractor = self.info.get_feature_extractor()
    # Dummy audio length: min(chunk_length, 30) times the sampling rate
    # (presumably seconds -> samples; confirm against the feature extractor).
    audio_length = min(extractor.chunk_length, 30) * extractor.sampling_rate
    width, height = self.info.get_image_size_with_most_features()
    num_frames = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

    overrides = {
        m: (mm_options.get(m) if mm_options else None)
        for m in ("image", "video", "audio")
    }

    return {
        "audio": self._get_dummy_audios(
            length=audio_length,
            num_audios=counts["audio"],
            overrides=overrides["audio"],
        ),
        "image": self._get_dummy_images(
            width=width,
            height=height,
            num_images=counts["image"],
            overrides=overrides["image"],
        ),
        "video": self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=num_frames,
            num_videos=counts["video"],
            overrides=overrides["video"],
        ),
    }

get_dummy_text

get_dummy_text(mm_counts: Mapping[str, int]) -> str
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    """Return one placeholder token per requested item.

    Audio placeholders come first, then image, then video.
    """
    processor = self.info.get_hf_processor()
    token_by_modality = (
        ("audio", processor.audio_token),
        ("image", processor.image_token),
        ("video", processor.video_token),
    )
    return "".join(
        token * mm_counts.get(modality, 0) for modality, token in token_by_modality
    )

Qwen2_5OmniThinkerForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsMRoPE, Qwen2_5OmniConditionalGenerationMixin

Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5OmniThinkerMultiModalProcessor,
    info=Qwen2_5OmniThinkerProcessingInfo,
    dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
)
class Qwen2_5OmniThinkerForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsPP,
    SupportsLoRA,
    SupportsMRoPE,
    Qwen2_5OmniConditionalGenerationMixin,
):
    """Thinker (text-generating) part of Qwen2.5-Omni.

    Holds an optional audio encoder (``audio_tower``), an optional vision
    transformer (``visual``) and a ``Qwen2ForCausalLM`` language model. Each
    tower is only built when the multimodal config allows its modality in a
    prompt; otherwise the attribute is ``None`` and its weights are skipped.
    """

    # Checkpoint keys live under a ``thinker.`` prefix; strip/remap it onto
    # this module's attribute names.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "thinker.lm_head.": "language_model.lm_head.",
            "thinker.model.": "language_model.model.",
            "thinker.": "",
        }
    )
    # Fused module name -> the separate checkpoint projections it packs.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "attn.qkv": [
            "attn.q",
            "attn.k",
            "attn.v",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder text for item ``i`` of ``modality``.

        Raises:
            ValueError: if ``modality`` does not start with ``image``,
                ``video`` or ``audio``.
        """
        if modality.startswith("image"):
            return "<|vision_start|><|IMAGE|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|VIDEO|><|vision_end|>"
        if modality.startswith("audio"):
            return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>"

        raise ValueError("Only image, video or audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the towers and the language model from ``vllm_config``.

        Args:
            vllm_config: Engine config; its ``model_config.hf_config`` must
                carry a ``thinker_config``.
            prefix: Module-name prefix used for the visual tower and the
                language model.
        """
        super().__init__()
        thinker_config: Qwen2_5OmniThinkerConfig = (
            vllm_config.model_config.hf_config.thinker_config
        )
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = thinker_config
        self.multimodal_config = multimodal_config

        # force "use_flash_attention_2=True" to audio tower to align
        # the results.
        if flash_attn is not None:
            audio_config = thinker_config.audio_config
            audio_config._attn_implementation_autoset = True
            audio_config._attn_implementation = "flash_attention_2"
        else:
            logger.warning(
                "flash_attn is not available, the model may not yield the "
                "exactly same result as the transformers implementation "
                "in the audio tower part."
            )

        # Only build the towers whose modality may appear in a prompt.
        if multimodal_config.get_limit_per_prompt("audio"):
            self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
        else:
            self.audio_tower = None

        if multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video"):
            self.visual = Qwen2_5_VisionTransformer(
                vision_config=thinker_config.vision_config,
                norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
            )
        else:
            self.visual = None

        self.quant_config = quant_config
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            hf_config=thinker_config.text_config,
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse multimodal kwargs into a ``{modality: parsed_input}`` dict.

        Modalities are inserted in the order their first trigger key appears
        in ``kwargs``. Each modality parser is called at most once.
        """
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "video" not in mm_input_by_modality
            ):
                mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                    **kwargs
                )
            # NOTE: a one-element tuple; a bare parenthesized string would
            # make ``in`` a substring test (e.g. matching a key "audio").
            if (
                input_key in ("input_audio_features",)
                and "audio" not in mm_input_by_modality
            ):
                mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
                    **kwargs
                )
        return mm_input_by_modality

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped text decoder."""
        return self.language_model

    @classmethod
    def get_mrope_input_positions(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        second_per_grid_ts: Optional[list[float]] = None,
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value (Qwen2.5-Omni version).

        Differences from MRotaryEmbedding:
            1. Add audio support (and related `audio_feature_lengths`).
            2. Add `use_audio_in_video` option to read audio from video inputs.
                In this case, audio and vision position ids will be split into
                chunks and interleaved.

        Example:

            (V_i are vision position ids, A_i are audio position ids)

            |V_1 ...    V_n|A_1 ...   A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
            |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...

        Args:
            input_tokens: Prompt token ids. Multimodal placeholders must
                already be expanded to their full token count; the scan skips
                ahead by the number of tokens generated for each placeholder.
            hf_config: Config exposing ``thinker_config``.
            image_grid_thw / video_grid_thw: Per-item ``(t, h, w)`` grids.
            second_per_grid_ts: Seconds per temporal grid step for each video.
                Defaults to 1 for every video.
            context_len / seq_len: Column range of positions to return.
            audio_feature_lengths: Per-audio feature lengths. Required when
                audio tokens are present or ``use_audio_in_video`` is set.
            use_audio_in_video: Interleave each video's audio with its frames.

        Returns:
            ``(positions, delta)``. ``positions`` has 3 rows, sliced to columns
            ``[context_len:seq_len]``. ``delta`` is
            ``max(position) + 1 - len(input_tokens)`` over the full sequence.
            It is produced by tensor arithmetic, so it is a 0-d tensor.
        """

        # TODO(fyabc): refactor and share more code with
        #  _vl_get_input_positions_tensor.

        thinker_config = hf_config.thinker_config
        audio_token_id = thinker_config.audio_token_index
        image_token_id = thinker_config.image_token_index
        video_token_id = thinker_config.video_token_index
        audio_start_token_id = thinker_config.audio_start_token_id
        audio_end_token_id = thinker_config.audio_end_token_id
        vision_start_token_id = thinker_config.vision_start_token_id
        vision_end_token_id = thinker_config.vision_end_token_id
        seconds_per_chunk = thinker_config.seconds_per_chunk
        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(
            thinker_config.vision_config, "tokens_per_second", 25
        )

        if isinstance(image_grid_thw, list):
            image_grid_thw = torch.tensor(image_grid_thw)
        if isinstance(video_grid_thw, list):
            video_grid_thw = torch.tensor(video_grid_thw)

        src_item = input_tokens
        audio_seqlens = audio_feature_lengths
        if not second_per_grid_ts:
            second_per_grid_ts = [1] * video_grid_thw.shape[0]
        audio_idx = 0
        video_idx = 0
        image_idx = 0
        # Tokens accounted for so far. Its growth in one step is the number
        # of input tokens the current placeholder covers.
        new_src_item: list[int] = []
        llm_pos_ids_list: list[torch.Tensor] = []

        idx = 0
        while idx < len(src_item):
            new_src_item_len = len(new_src_item)
            start_idx = (
                llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            )
            if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]:
                if use_audio_in_video and idx > 0:
                    if (
                        src_item[idx] == vision_end_token_id
                        and src_item[idx - 1] == audio_end_token_id
                    ):
                        # vision end token right after <|audio_eos|>: share
                        # the previous position.
                        start_idx -= 1
                    elif (
                        src_item[idx] == audio_start_token_id
                        and src_item[idx - 1] == vision_start_token_id
                    ):
                        # <|audio_bos|> right after the vision start token:
                        # share the previous position.
                        start_idx -= 1
                new_src_item.append(src_item[idx])
                llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
                llm_pos_ids_list.append(llm_pos_ids)
            elif src_item[idx] == audio_token_id:
                assert audio_seqlens is not None
                # int() so that list repetition below never sees a 0-d tensor.
                audio_seqlen = int(audio_seqlens[audio_idx])
                place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
                new_src_item.extend([audio_token_id] * place_num)
                llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
                llm_pos_ids_list.append(llm_pos_ids)
                audio_idx += 1
            elif src_item[idx] == image_token_id:
                grid_t = image_grid_thw[image_idx][0]
                grid_hs = image_grid_thw[:, 1]
                grid_ws = image_grid_thw[:, 2]
                t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
                llm_pos_ids = get_llm_pos_ids_for_vision(
                    start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                )
                llm_pos_ids_list.append(llm_pos_ids)
                vision_seqlen = image_grid_thw[image_idx].prod() // (
                    spatial_merge_size**2
                )
                new_src_item.extend([image_token_id] * vision_seqlen)
                image_idx += 1
            elif src_item[idx] == video_token_id and not use_audio_in_video:
                grid_t = video_grid_thw[video_idx][0]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_index = (
                    torch.arange(grid_t)
                    * second_per_grid_ts[video_idx]
                    * tokens_per_second
                ).long()
                llm_pos_ids = get_llm_pos_ids_for_vision(
                    start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                )
                llm_pos_ids_list.append(llm_pos_ids)
                vision_seqlen = video_grid_thw[video_idx].prod() // (
                    spatial_merge_size**2
                )
                new_src_item.extend([video_token_id] * vision_seqlen)
                video_idx += 1
            else:
                # read audio from video: interleave per-chunk vision tokens
                # with the matching slice of audio tokens.
                assert audio_seqlens is not None
                audio_seqlen = int(audio_seqlens[audio_idx])
                grid_t = video_grid_thw[video_idx][0]
                grid_h = video_grid_thw[video_idx][1]
                grid_w = video_grid_thw[video_idx][2]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
                t_index = (
                    torch.arange(grid_t)
                    * second_per_grid_ts[video_idx]
                    * tokens_per_second
                ).long()
                t_index_split_chunk = split_list_into_ranges(
                    t_index, t_ntoken_per_chunk
                )
                # Audio placeholder count; same arithmetic as the pure-audio
                # branch above.
                pure_audio_len = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
                added_audio_len = 0
                audio_llm_pos_ids_list: list[torch.Tensor] = []
                for t_chunk in t_index_split_chunk:
                    vision_ntoken_per_chunk = (
                        len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
                    )
                    new_src_item.extend([video_token_id] * vision_ntoken_per_chunk)
                    vision_llm_pos_ids_list = get_llm_pos_ids_for_vision(
                        start_idx,
                        video_idx,
                        spatial_merge_size,
                        t_chunk,
                        grid_hs,
                        grid_ws,
                    ).split(1, dim=1)
                    llm_pos_ids_list.extend(vision_llm_pos_ids_list)
                    # Audio tokens that follow this vision chunk.
                    audio_chunk_len = min(
                        t_ntoken_per_chunk, pure_audio_len - added_audio_len
                    )
                    new_src_item.extend(audio_chunk_len * [audio_token_id])
                    audio_start_idx = (
                        start_idx
                        if len(audio_llm_pos_ids_list) == 0
                        else audio_llm_pos_ids_list[-1][0].item() + 1
                    )
                    if audio_chunk_len > 0:
                        audio_llm_pos_ids_list = (
                            torch.arange(audio_chunk_len).expand(3, -1)
                            + audio_start_idx
                        ).split(1, dim=1)
                    else:
                        audio_llm_pos_ids_list = []
                    added_audio_len += audio_chunk_len
                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
                # Audio longer than the video: append the remaining tokens.
                if added_audio_len < pure_audio_len:
                    new_src_item.extend(
                        (pure_audio_len - added_audio_len) * [audio_token_id]
                    )
                    audio_llm_pos_ids_list = (
                        torch.arange(pure_audio_len - added_audio_len).expand(3, -1)
                        + llm_pos_ids_list[-1].max()
                        + 1
                    ).split(1, dim=1)
                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
                audio_idx += 1
                video_idx += 1
            # move to the next token
            idx += len(new_src_item) - new_src_item_len

        llm_positions = torch.cat(llm_pos_ids_list, dim=1)
        # Delta is taken over the full sequence, before slicing.
        mrope_position_delta = llm_positions.max() + 1 - len(src_item)
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode every multimodal input present in ``kwargs``.

        Returns ``[]`` when no multimodal input is present. Otherwise returns
        a tuple concatenating the per-modality embeddings, in the order the
        modalities were parsed.
        """
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image, video or
        # audio).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in mm_input_by_modality:
            multimodal_input = mm_input_by_modality[modality]
            if modality == "image":
                vision_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings += vision_embeddings
            if modality == "video":
                video_embeddings = self._process_video_input(multimodal_input)
                multimodal_embeddings += video_embeddings
            if modality == "audio":
                audio_embeddings = self._process_audio_input(multimodal_input)
                multimodal_embeddings += audio_embeddings
        return multimodal_embeddings

    # TODO (ywang96): support overlapping modality embeddings so that
    # `use_audio_in_video` will work on V1.
    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids``, merging multimodal embeddings when given.

        Falls back to text-only embedding unless both
        ``multimodal_embeddings`` and ``is_multimodal`` are provided.
        """
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().get_input_embeddings(input_ids)

        return super().get_input_embeddings(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def get_multimodal_embeddings_v0(self, **kwargs: object) -> Optional[NestedTensors]:
        """Legacy path: return ``(embeddings, modality)`` pairs.

        The pairs are ordered audio, image, video. Returns ``None`` when no
        multimodal input is present.
        """
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        image_input = self._parse_and_validate_image_input(**kwargs)
        video_input = self._parse_and_validate_video_input(**kwargs)

        if audio_input is None and image_input is None and video_input is None:
            return None

        multimodal_embeddings: list[tuple[NestedTensors, str]] = []

        if audio_input is not None:
            audio_embeds = self._process_audio_input(audio_input)
            multimodal_embeddings.append((audio_embeds, "audio"))
        if image_input is not None:
            image_embeds = self._process_image_input(image_input)
            multimodal_embeddings.append((image_embeds, "image"))
        if video_input is not None:
            video_embeds = self._process_video_input(video_input)
            multimodal_embeddings.append((video_embeds, "video"))
        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the text decoder.

        ``inputs_embeds`` is discarded whenever ``intermediate_tensors`` is
        supplied. Extra ``kwargs`` are accepted but unused.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module.

        Names are remapped with ``hf_to_vllm_mapper``. ``talker.`` and
        ``token2wav.`` weights are always skipped, as are the weights of any
        tower that was not built.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` reports as loaded.
        """
        skip_prefixes = ["talker.", "token2wav."]
        if self.audio_tower is None:
            skip_prefixes.append("audio_tower.")
        if self.visual is None:
            skip_prefixes.append("visual.")

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=skip_prefixes,
        )
        loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

        return loaded_weights

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefixes of the language model, the connector and the
        encoder towers of this multimodal model.
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="merger.",
            tower_model=["visual.", "audio_tower."],
        )

audio_tower instance-attribute

audio_tower = Qwen2_5OmniAudioEncoder(audio_config)

config instance-attribute

config = thinker_config

hf_to_vllm_mapper class-attribute instance-attribute

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "thinker.lm_head.": "language_model.lm_head.",
        "thinker.model.": "language_model.model.",
        "thinker.": "",
    }
)

language_model instance-attribute

language_model = init_vllm_registered_model(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "language_model"),
    hf_config=text_config,
    architectures=["Qwen2ForCausalLM"],
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

multimodal_config instance-attribute

multimodal_config = multimodal_config

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "attn.qkv": ["attn.q", "attn.k", "attn.v"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

quant_config instance-attribute

quant_config = quant_config

visual instance-attribute

visual = Qwen2_5_VisionTransformer(
    vision_config=vision_config,
    norm_eps=getattr(text_config, "rms_norm_eps", 1e-06),
    quant_config=quant_config,
    prefix=maybe_prefix(prefix, "visual"),
)

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Construct the thinker: optional audio/vision towers plus the text LM.

    A tower is left as ``None`` when the multimodal config gives its
    modality no per-prompt allowance.
    """
    super().__init__()
    model_cfg = vllm_config.model_config
    thinker_config: Qwen2_5OmniThinkerConfig = model_cfg.hf_config.thinker_config
    quant_config = vllm_config.quant_config
    mm_cfg = model_cfg.multimodal_config
    self.config = thinker_config
    self.multimodal_config = mm_cfg

    if flash_attn is None:
        logger.warning(
            "flash_attn is not available, the model may not yield the "
            "exactly same result as the transformers implementation "
            "in the audio tower part."
        )
    else:
        # Pin the audio tower to FlashAttention-2 so its outputs line up
        # with the reference implementation.
        audio_cfg = thinker_config.audio_config
        audio_cfg._attn_implementation_autoset = True
        audio_cfg._attn_implementation = "flash_attention_2"

    self.audio_tower = (
        Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
        if mm_cfg.get_limit_per_prompt("audio")
        else None
    )

    wants_vision = mm_cfg.get_limit_per_prompt(
        "image"
    ) or mm_cfg.get_limit_per_prompt("video")
    self.visual = (
        Qwen2_5_VisionTransformer(
            vision_config=thinker_config.vision_config,
            norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
        )
        if wants_vision
        else None
    )

    self.quant_config = quant_config
    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        prefix=maybe_prefix(prefix, "language_model"),
        hf_config=thinker_config.text_config,
        architectures=["Qwen2ForCausalLM"],
    )
    # Reuse the language model's factory for empty intermediate tensors.
    lm = self.language_model
    self.make_empty_intermediate_tensors = lm.make_empty_intermediate_tensors

_parse_and_validate_multimodal_inputs

_parse_and_validate_multimodal_inputs(
    **kwargs: object,
) -> dict
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
    """Parse multimodal kwargs into a ``{modality: parsed_input}`` dict.

    Modalities are inserted in the order their first trigger key appears in
    ``kwargs``. Each modality parser is called at most once.
    """
    mm_input_by_modality = {}

    # Preserve the order of modalities if there are multiple of them
    # from the order of kwargs.
    for input_key in kwargs:
        if (
            input_key in ("pixel_values", "image_embeds")
            and "image" not in mm_input_by_modality
        ):
            mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                **kwargs
            )
        if (
            input_key in ("pixel_values_videos", "video_embeds")
            and "video" not in mm_input_by_modality
        ):
            mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                **kwargs
            )
        # NOTE: a one-element tuple; a bare parenthesized string would make
        # ``in`` a substring test (e.g. matching a key named "audio").
        if (
            input_key in ("input_audio_features",)
            and "audio" not in mm_input_by_modality
        ):
            mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
                **kwargs
            )
    return mm_input_by_modality

compute_logits

compute_logits(hidden_states: Tensor) -> Optional[Tensor]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
    """Turn decoder hidden states into logits via the wrapped language model."""
    lm = self.language_model
    logits = lm.compute_logits(hidden_states)
    return logits

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs: object,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the text decoder and return its output.

    When ``intermediate_tensors`` is supplied, ``inputs_embeds`` is dropped.
    Extra keyword arguments are accepted and ignored.
    """
    embeds = inputs_embeds if intermediate_tensors is None else None
    return self.language_model.model(
        input_ids, positions, intermediate_tensors, inputs_embeds=embeds
    )

get_input_embeddings

get_input_embeddings(
    input_ids: Tensor,
    multimodal_embeddings: Optional[
        MultiModalEmbeddings
    ] = None,
    *,
    is_multimodal: Optional[Tensor] = None,
    handle_oov_mm_token: bool = False,
) -> Tensor
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    *,
    is_multimodal: Optional[torch.Tensor] = None,
    handle_oov_mm_token: bool = False,
) -> torch.Tensor:
    """Embed ``input_ids``, splicing in multimodal embeddings when available.

    The multimodal form of the parent call is used only when both
    ``multimodal_embeddings`` and ``is_multimodal`` are given.
    """
    has_mm = multimodal_embeddings is not None and is_multimodal is not None
    if has_mm:
        return super().get_input_embeddings(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )
    # Text-only form; calling it separately keeps each parent overload
    # type-correct.
    return super().get_input_embeddings(input_ids)

get_language_model

get_language_model() -> Module
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def get_language_model(self) -> torch.nn.Module:
    """Expose the wrapped text decoder module."""
    text_decoder = self.language_model
    return text_decoder

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix in multimodal models

Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def get_mm_mapping(self) -> MultiModelKeys:
    """
    Describe which module prefixes hold the language model, the connector
    and the encoder towers of this multimodal model.
    """
    towers = ["visual.", "audio_tower."]
    return MultiModelKeys.from_string_field(
        tower_model=towers,
        connector="merger.",
        language_model="language_model",
    )

get_mrope_input_positions classmethod

get_mrope_input_positions(
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], Tensor],
    video_grid_thw: Union[list[list[int]], Tensor],
    second_per_grid_ts: Optional[list[float]] = None,
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[Tensor, int]

Get mrope input positions and delta value (Qwen2.5-Omni version).

Differences from MRotaryEmbedding
  1. Add audio support (and related audio_feature_lengths).
  2. Add use_audio_in_video option to read audio from video inputs. In this case, audio and vision position ids will be split into chunks and interleaved.

Example:

(V_i are vision position ids, A_i are audio position ids)

|V_1 ...    V_n|A_1 ...   A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
@classmethod
def get_mrope_input_positions(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    second_per_grid_ts: Optional[list[float]] = None,
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value (Qwen2.5-Omni version).

    Differences from MRotaryEmbedding:
        1. Add audio support (and related `audio_feature_lengths`).
        2. Add `use_audio_in_video` option to read audio from video inputs.
            In this case, audio and vision position ids will be split into
            chunks and interleaved.

    Example:

        (V_i are vision position ids, A_i are audio position ids)

        |V_1 ...    V_n|A_1 ...   A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
        |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
    """

    # TODO(fyabc): refactor and share more code with
    #  _vl_get_input_positions_tensor.

    thinker_config = hf_config.thinker_config
    audio_token_id = thinker_config.audio_token_index
    image_token_id = thinker_config.image_token_index
    video_token_id = thinker_config.video_token_index
    audio_start_token_id = thinker_config.audio_start_token_id
    audio_end_token_id = thinker_config.audio_end_token_id
    vision_start_token_id = thinker_config.vision_start_token_id
    vision_end_token_id = thinker_config.vision_end_token_id
    seconds_per_chunk = thinker_config.seconds_per_chunk
    spatial_merge_size = thinker_config.vision_config.spatial_merge_size
    tokens_per_second = getattr(
        thinker_config.vision_config, "tokens_per_second", 25
    )

    if isinstance(image_grid_thw, list):
        image_grid_thw = torch.tensor(image_grid_thw)
    if isinstance(video_grid_thw, list):
        video_grid_thw = torch.tensor(video_grid_thw)

    src_item = input_tokens
    audio_seqlens = audio_feature_lengths
    if not second_per_grid_ts:
        second_per_grid_ts = [1] * video_grid_thw.shape[0]
    audio_idx = 0
    video_idx = 0
    image_idx = 0
    new_src_item: list[int] = []
    llm_pos_ids_list: list[torch.Tensor] = []

    idx = 0
    while idx < len(src_item):
        new_src_item_len = len(new_src_item)
        start_idx = (
            llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
        )
        if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]:
            if use_audio_in_video and idx > 0:
                if (
                    src_item[idx] == vision_end_token_id
                    and src_item[idx - 1] == audio_end_token_id
                ):
                    # processing the <|audio_eos|> before <|vision_eos|>
                    start_idx -= 1
                elif (
                    src_item[idx] == audio_start_token_id
                    and src_item[idx - 1] == vision_start_token_id
                ):
                    # processing the <|audio_bos|> after <|vision_eos|>
                    start_idx -= 1
            new_src_item.append(src_item[idx])
            llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
            llm_pos_ids_list.append(llm_pos_ids)
        elif src_item[idx] == audio_token_id:
            assert audio_seqlens is not None
            audio_seqlen = audio_seqlens[audio_idx]
            place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
            new_src_item.extend([audio_token_id] * place_num)
            llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
            llm_pos_ids_list.append(llm_pos_ids)
            audio_idx += 1
        elif src_item[idx] == image_token_id:
            grid_t = image_grid_thw[image_idx][0]
            grid_hs = image_grid_thw[:, 1]
            grid_ws = image_grid_thw[:, 2]
            t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
            llm_pos_ids = get_llm_pos_ids_for_vision(
                start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
            )
            llm_pos_ids_list.append(llm_pos_ids)
            vision_seqlen = image_grid_thw[image_idx].prod() // (
                spatial_merge_size**2
            )
            new_src_item.extend([image_token_id] * vision_seqlen)
            image_idx += 1
        elif src_item[idx] == video_token_id and not use_audio_in_video:
            grid_t = video_grid_thw[video_idx][0]
            grid_hs = video_grid_thw[:, 1]
            grid_ws = video_grid_thw[:, 2]
            t_index = (
                torch.arange(grid_t)
                * second_per_grid_ts[video_idx]
                * tokens_per_second
            ).long()
            llm_pos_ids = get_llm_pos_ids_for_vision(
                start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
            )
            llm_pos_ids_list.append(llm_pos_ids)
            vision_seqlen = video_grid_thw[video_idx].prod() // (
                spatial_merge_size**2
            )
            new_src_item.extend([video_token_id] * vision_seqlen)
            video_idx += 1
        else:
            # read audio from video
            assert audio_seqlens is not None
            audio_seqlen = audio_seqlens[audio_idx]
            vision_seqlen = video_grid_thw[video_idx].prod() // (
                spatial_merge_size**2
            )
            grid_t = video_grid_thw[video_idx][0]
            grid_h = video_grid_thw[video_idx][1]
            grid_w = video_grid_thw[video_idx][2]
            grid_hs = video_grid_thw[:, 1]
            grid_ws = video_grid_thw[:, 2]
            t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
            t_index = (
                torch.arange(grid_t)
                * second_per_grid_ts[video_idx]
                * tokens_per_second
            ).long()
            t_index_split_chunk = split_list_into_ranges(
                t_index, t_ntoken_per_chunk
            )
            place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
            pure_audio_len = place_num - 2
            added_audio_len = 0
            audio_llm_pos_ids_list: list[torch.Tensor] = []
            for t_chunk in t_index_split_chunk:
                vision_ntoken_per_chunk = (
                    len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
                )
                new_src_item.extend([video_token_id] * vision_ntoken_per_chunk)
                vision_llm_pos_ids_list = get_llm_pos_ids_for_vision(
                    start_idx,
                    video_idx,
                    spatial_merge_size,
                    t_chunk,
                    grid_hs,
                    grid_ws,
                ).split(1, dim=1)
                llm_pos_ids_list.extend(vision_llm_pos_ids_list)
                new_src_item.extend(
                    min(t_ntoken_per_chunk, pure_audio_len - added_audio_len)
                    * [audio_token_id]
                )
                audio_start_idx = (
                    start_idx
                    if len(audio_llm_pos_ids_list) == 0
                    else audio_llm_pos_ids_list[-1][0].item() + 1
                )
                if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0:
                    audio_llm_pos_ids_list = (
                        torch.arange(
                            min(
                                t_ntoken_per_chunk, pure_audio_len - added_audio_len
                            )
                        ).expand(3, -1)
                        + audio_start_idx
                    ).split(1, dim=1)
                else:
                    audio_llm_pos_ids_list = []
                added_audio_len += min(
                    t_ntoken_per_chunk, pure_audio_len - added_audio_len
                )
                llm_pos_ids_list.extend(audio_llm_pos_ids_list)
            if added_audio_len < pure_audio_len:
                new_src_item.extend(
                    (pure_audio_len - added_audio_len) * [audio_token_id]
                )
                audio_llm_pos_ids_list = (
                    torch.arange(pure_audio_len - added_audio_len).expand(3, -1)
                    + llm_pos_ids_list[-1].max()
                    + 1
                ).split(1, dim=1)
                llm_pos_ids_list.extend(audio_llm_pos_ids_list)
            audio_idx += 1
            video_idx += 1
        # move to the next token
        idx += len(new_src_item) - new_src_item_len

    llm_positions = torch.cat(llm_pos_ids_list, dim=1)
    mrope_position_delta = (
        torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item)
    )
    llm_positions = llm_positions[:, context_len:seq_len]

    return llm_positions, mrope_position_delta

get_multimodal_embeddings

get_multimodal_embeddings(
    **kwargs: object,
) -> MultiModalEmbeddings
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
    """Encode every multimodal input found in ``kwargs``.

    Returns an empty list when no multimodal input is present. Otherwise
    returns a tuple that concatenates the per-item embeddings produced by
    ``_process_image_input`` / ``_process_video_input`` /
    ``_process_audio_input``. Modalities are visited in the order of the
    parsed-input dictionary, so each modality's embeddings are contiguous.
    """
    parsed_inputs = self._parse_and_validate_multimodal_inputs(**kwargs)
    if not parsed_inputs:
        return []

    collected: tuple[torch.Tensor, ...] = ()

    # Walk the dictionary itself so that the modality order it encodes is
    # preserved in the output.
    for modality, modality_input in parsed_inputs.items():
        if modality == "image":
            collected += self._process_image_input(modality_input)
        elif modality == "video":
            collected += self._process_video_input(modality_input)
        elif modality == "audio":
            collected += self._process_audio_input(modality_input)
    return collected

get_multimodal_embeddings_v0

get_multimodal_embeddings_v0(
    **kwargs: object,
) -> Optional[NestedTensors]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def get_multimodal_embeddings_v0(self, **kwargs: object) -> Optional[NestedTensors]:
    """V0 path: return ``(embeddings, modality)`` pairs, or ``None``.

    All three modalities are parsed first (audio, image, video). ``None`` is
    returned when none of them is present. Otherwise each present modality is
    encoded, in that same order, and paired with its modality name.
    """
    parsed_inputs = [
        ("audio", self._parse_and_validate_audio_input(**kwargs)),
        ("image", self._parse_and_validate_image_input(**kwargs)),
        ("video", self._parse_and_validate_video_input(**kwargs)),
    ]
    if all(parsed is None for _, parsed in parsed_inputs):
        return None

    encoded: list[tuple[NestedTensors, str]] = []
    for modality, parsed in parsed_inputs:
        if parsed is None:
            continue
        # Resolves to _process_audio_input / _process_image_input /
        # _process_video_input, looked up only for modalities present.
        encoder = getattr(self, f"_process_{modality}_input")
        encoded.append((encoder(parsed), modality))
    return encoded

get_placeholder_str classmethod

get_placeholder_str(modality: str, i: int) -> Optional[str]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
    """Return the prompt placeholder text for the ``i``-th item of ``modality``.

    The modality is matched by prefix, so e.g. ``"image_embeds"`` maps to the
    image placeholder. Only audio placeholders include the index ``i``.

    Raises:
        ValueError: if ``modality`` starts with none of image/video/audio.
    """
    vision_placeholders = {
        "image": "<|vision_start|><|IMAGE|><|vision_end|>",
        "video": "<|vision_start|><|VIDEO|><|vision_end|>",
    }
    for prefix, placeholder in vision_placeholders.items():
        if modality.startswith(prefix):
            return placeholder

    if modality.startswith("audio"):
        return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>"

    raise ValueError("Only image, video or audio modality is supported")

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights and return the set of loaded names.

    Weights under ``talker.`` and ``token2wav.`` are always skipped. Weights
    under ``audio_tower.`` / ``visual.`` are skipped when the matching
    submodule is ``None``. Names are remapped through ``hf_to_vllm_mapper``.
    """
    optional_modules = (
        ("audio_tower.", self.audio_tower),
        ("visual.", self.visual),
    )
    skip_prefixes = ["talker.", "token2wav."]
    skip_prefixes += [prefix for prefix, module in optional_modules if module is None]

    loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
    return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

Qwen2_5OmniThinkerMultiModalDataParser

Bases: Qwen2VLMultiModalDataParser

Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
    """Data parser that additionally accepts pre-computed audio features.

    Audio supplied as a dict (which must contain ``input_audio_features`` and
    ``audio_feature_lengths``) is wrapped in ``DictEmbeddingItems``. Any other
    audio input is handled by the parent parser.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        # Keep the merge size for the dict-audio field factory; the parent
        # parser receives it as its first positional argument.
        self._spatial_merge_size = spatial_merge_size
        super().__init__(spatial_merge_size, *args, **kwargs)

    def _parse_audio_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> ModalityDataItems[Any, Any]:
        """Parse audio data, treating a dict as pre-computed features."""
        # NOTE(review): the annotation says ImageItem although this parses
        # audio; an audio item type is probably intended -- confirm.
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        fields_factory = create_qwen2_5_omni_thinker_field_factory(
            self._spatial_merge_size
        )
        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"input_audio_features", "audio_feature_lengths"},
            fields_factory=fields_factory,
        )

_spatial_merge_size instance-attribute

_spatial_merge_size = spatial_merge_size

__init__

__init__(spatial_merge_size: int, *args, **kwargs)
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def __init__(self, spatial_merge_size: int, *args, **kwargs):
    """Store ``spatial_merge_size`` and forward it to the parent parser."""
    # Kept on the instance for the dict-audio field factory.
    self._spatial_merge_size = spatial_merge_size
    super().__init__(spatial_merge_size, *args, **kwargs)

_parse_audio_data

_parse_audio_data(
    data: Union[dict[str, Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _parse_audio_data(
    self,
    data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
    """Parse audio data, treating a dict as pre-computed audio features.

    A dict must provide ``input_audio_features`` and ``audio_feature_lengths``
    and is wrapped in ``DictEmbeddingItems``. Anything else goes to the
    parent parser.
    """
    # NOTE(review): the annotation says ImageItem although this parses
    # audio; an audio item type is probably intended -- confirm.
    if not isinstance(data, dict):
        return super()._parse_audio_data(data)

    fields_factory = create_qwen2_5_omni_thinker_field_factory(
        self._spatial_merge_size
    )
    return DictEmbeddingItems(
        data,
        modality="audio",
        required_fields={"input_audio_features", "audio_feature_lengths"},
        fields_factory=fields_factory,
    )

Qwen2_5OmniThinkerMultiModalProcessor

Bases: BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo]

Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
class Qwen2_5OmniThinkerMultiModalProcessor(
    BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo]
):
    """Multimodal processor for the Qwen2.5-Omni thinker.

    On top of the base processor it normalizes the HF audio outputs and
    supports ``use_audio_in_video``. In that mode each video placeholder is
    expanded into interleaved chunks of video and audio tokens, and the audio
    items consumed by videos are not expected as standalone audio
    placeholders.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        """Build the data parser from the vision config's spatial merge size
        and the feature extractor's sampling rate."""
        feature_extractor = self.info.get_feature_extractor()
        return Qwen2_5OmniThinkerMultiModalDataParser(
            spatial_merge_size=self.info.get_hf_config().vision_config.spatial_merge_size,
            target_sr=feature_extractor.sampling_rate,
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the HF processor and normalize its outputs for vLLM.

        - ``mm_data["audios"]`` is passed to the HF processor as ``audio``.
        - ``input_features`` is moved to ``input_audio_features``. It is
          packed with ``feature_attention_mask`` when that mask is present.
        - ``audio_feature_lengths`` is derived from the mask when missing.
        - ``video_second_per_grid`` is copied to ``second_per_grid_ts``.
        - ``use_audio_in_video`` from ``mm_kwargs`` is stored as a tensor.
        """
        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])

        # NOTE: WhisperFeatureExtractor cannot handle an empty list of audios,
        # so the "audio" key is only set when at least one audio is given.
        if audios:
            # NOTE: the Qwen2.5-Omni HF processor accepts the key "audio".
            mm_data["audio"] = audios
            # Copy into a plain dict; no extra kwargs are added at present.
            mm_kwargs = dict(
                **mm_kwargs,
            )

        hf_inputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        input_features = hf_inputs.pop("input_features", None)
        feature_attention_mask = hf_inputs.get("feature_attention_mask", None)
        if "input_audio_features" not in hf_inputs and input_features is not None:
            if feature_attention_mask is not None:
                # Keep only positions where the mask is set and pack all
                # audios along the last axis. Assumes input_features is
                # (num_audios, num_mel_bins, max_len) and the mask is
                # (num_audios, max_len) -- TODO confirm against HF output.
                input_features = input_features.permute(0, 2, 1)[
                    feature_attention_mask.bool()
                ].permute(1, 0)
            hf_inputs["input_audio_features"] = input_features
        if (
            "audio_feature_lengths" not in hf_inputs
            and feature_attention_mask is not None
        ):
            # Count of set mask positions along the last axis.
            hf_inputs["audio_feature_lengths"] = feature_attention_mask.sum(-1)

        # Copy to ``second_per_grid_ts``, the key used for video timing
        # elsewhere in this model.
        video_second_per_grid = hf_inputs.get("video_second_per_grid", None)
        if video_second_per_grid is not None:
            hf_inputs["second_per_grid_ts"] = video_second_per_grid

        use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
        hf_inputs["use_audio_in_video"] = torch.tensor(use_audio_in_video)

        return hf_inputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field layout of ``hf_inputs`` built by the shared
        Qwen2.5-Omni field factory (``hf_processor_mm_kwargs`` is unused)."""
        return create_qwen2_5_omni_thinker_field_factory(
            self.info.get_hf_config().vision_config.spatial_merge_size
        )(hf_inputs)

    def _maybe_apply_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        prompt_ids: list[int],
        mm_kwargs: MultiModalKwargsItems,
        mm_prompt_updates: MultiModalPromptUpdates,
        is_update_applied: bool,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.

        ``use_audio_in_video`` is treated as enabled only when there is at
        least one non-None video item and every such item has a true
        ``use_audio_in_video`` entry. The flag is forwarded to
        ``_validate_mm_placeholders``.
        """
        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
        self._validate_mm_updates(mm_prompt_updates, mm_item_counts)

        use_audio_in_video = False
        if "video" in mm_kwargs:
            video_items = [item for item in mm_kwargs["video"] if item is not None]
            # only check video items (if there are any)
            if video_items:
                use_audio_in_video = all(
                    item["use_audio_in_video"].data for item in video_items
                )

        if is_update_applied:
            # Updates are already in prompt_ids; only locate the placeholders.
            mm_placeholders = self._find_mm_placeholders(
                prompt_ids,
                mm_prompt_updates,
            )
            self._validate_mm_placeholders(
                mm_placeholders,
                mm_item_counts,
                use_audio_in_video=use_audio_in_video,
            )
        else:
            # Apply the updates here, which also yields the placeholders.
            prompt_ids, mm_placeholders = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
            )
            self._validate_mm_placeholders(
                mm_placeholders,
                mm_item_counts,
                use_audio_in_video=use_audio_in_video,
            )

        return prompt_ids, mm_placeholders

    @classmethod
    def omni_get_updates_use_audio_in_video(
        cls,
        thinker_config: PretrainedConfig,
        audio_len: int,
        video_grid_thw: Union[list[int], torch.Tensor],
        video_second_per_grid_t: float,
    ) -> list[int]:
        """Get video prompt updates when `use_audio_in_video` is True.

        In this case, audio and vision update ids will be split into
        chunks and interleaved (details in `_omni_get_input_positions_tensor`).

        <|video_bos|><|VIDEO|><|video_eos|> =>
        <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>

        The returned ids replace the video token only. They start with the
        audio start token and end with the audio end token. Each temporal
        chunk contributes its vision tokens followed by up to
        ``int(tokens_per_second * seconds_per_chunk)`` audio tokens. Audio
        tokens still left after the last chunk are appended before the audio
        end token.

        Args:
            thinker_config: thinker HF config supplying the token ids,
                ``seconds_per_chunk`` and ``vision_config``.
            audio_len: number of audio tokens for this video's audio.
            video_grid_thw: the video's ``(t, h, w)`` grid.
            video_second_per_grid_t: presumably the seconds spanned by one
                temporal grid step; it scales the temporal index below.
        """

        audio_token_id = thinker_config.audio_token_index
        video_token_id = thinker_config.video_token_index
        audio_start_token_id = thinker_config.audio_start_token_id
        audio_end_token_id = thinker_config.audio_end_token_id
        seconds_per_chunk = thinker_config.seconds_per_chunk
        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
        # Falls back to 25 when the vision config does not define it.
        tokens_per_second = getattr(
            thinker_config.vision_config, "tokens_per_second", 25
        )

        grid_t = video_grid_thw[0]
        grid_h = video_grid_thw[1]
        grid_w = video_grid_thw[2]
        # Audio tokens per chunk; also the temporal-index width of a chunk.
        t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
        t_index = (
            torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second
        ).long()
        t_index_split_chunk = split_list_into_ranges(t_index, t_ntoken_per_chunk)

        updates = [audio_start_token_id]
        added_audio_len = 0
        for t_chunk in t_index_split_chunk:
            vision_ntoken_per_chunk = (
                len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
            )
            updates.extend([video_token_id] * vision_ntoken_per_chunk)

            # Never emit more audio tokens than the audio actually has.
            audio_chunk_size = min(t_ntoken_per_chunk, audio_len - added_audio_len)
            updates.extend(audio_chunk_size * [audio_token_id])
            added_audio_len += audio_chunk_size
        # Audio longer than the video: append the remainder after all chunks.
        if added_audio_len < audio_len:
            updates.extend((audio_len - added_audio_len) * [audio_token_id])
        updates.extend([audio_end_token_id])

        return updates

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Build the placeholder replacements for audio, image and video.

        An audio placeholder expands to one audio token per output feature.
        Image and video placeholders expand to
        ``prod(grid_thw) // merge_size**2`` tokens. With
        ``use_audio_in_video``, a video placeholder instead expands to
        interleaved video/audio chunks.

        NOTE: the replacement callables share the mutable counter
        ``audio_in_video_item_idx``, so their results depend on the order in
        which they are invoked.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = tokenizer.get_vocab()

        audio_token = processor.audio_token
        image_token = processor.image_token
        video_token = processor.video_token
        audio_token_id = vocab[audio_token]
        image_token_id = vocab[image_token]
        video_token_id = vocab[video_token]

        # Audio token count per audio item, derived from the explicit feature
        # lengths if present, otherwise from the feature attention mask.
        out_mm_data = out_mm_kwargs.get_data()
        audio_feature_lengths = out_mm_data.get("audio_feature_lengths")
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        if audio_feature_lengths is None and feature_attention_mask is None:
            audio_output_lengths = []
        elif audio_feature_lengths is not None:
            _, audio_output_lens = _get_feat_extract_output_lengths(
                audio_feature_lengths
            )
            audio_output_lengths = audio_output_lens.tolist()
        elif feature_attention_mask is not None:
            assert isinstance(feature_attention_mask, torch.Tensor)
            _, audio_output_lens = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )
            audio_output_lengths = audio_output_lens.tolist()

        # number of audios read from video.
        audio_in_video_item_idx = 0

        def get_replacement_qwen2_audio(item_idx: int):
            # Offset by the audios already consumed by videos; the counter is
            # only advanced by the use_audio_in_video replacement below.
            item_idx += audio_in_video_item_idx

            num_features = audio_output_lengths[item_idx]
            if num_features == 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio = audios.get(item_idx)
                raise ValueError(
                    f"The audio {audio} (len={len(audio)}) is too short "
                    "to be represented inside the model"
                )

            return [audio_token_id] * num_features

        def get_replacement_qwen2_vision(item_idx: int, modality: str):
            grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx]
            assert isinstance(grid_thw, torch.Tensor)
            merge_length = image_processor.merge_size**2

            token_id = image_token_id if modality == "image" else video_token_id
            return [token_id] * (int(grid_thw.prod()) // merge_length)

        use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False)
        thinker_config = self.info.get_hf_config()

        def get_replacement_qwen2_use_audio_in_video(item_idx: int):
            nonlocal audio_in_video_item_idx

            # NOTE(review): if this runs exactly once per video in order, the
            # counter equals item_idx here and the index becomes
            # 2 * item_idx. Verify against how audio items are ordered,
            # especially with more than one video.
            audio_num_features = audio_output_lengths[
                audio_in_video_item_idx + item_idx
            ]
            video_grid_thw = out_mm_data["video_grid_thw"][item_idx]

            audio_in_video_item_idx += 1

            # NOTE(review): this reads hf_processor_mm_kwargs, whereas
            # _call_hf_processor stores second_per_grid_ts in the HF outputs.
            # Unless the caller passes it as a kwarg, 1.0 is used. Also,
            # `if` on a multi-element tensor would raise -- confirm intent.
            second_per_grid_ts = hf_processor_mm_kwargs.get("second_per_grid_ts", None)
            if second_per_grid_ts:
                video_second_per_grid_t = second_per_grid_ts[item_idx]
            else:
                video_second_per_grid_t = 1.0

            return self.omni_get_updates_use_audio_in_video(
                thinker_config=thinker_config,
                audio_len=audio_num_features,
                video_grid_thw=video_grid_thw,
                video_second_per_grid_t=video_second_per_grid_t,
            )

        video_replacement_fn = (
            get_replacement_qwen2_use_audio_in_video
            if use_audio_in_video
            else partial(get_replacement_qwen2_vision, modality="video")
        )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_qwen2_audio,
            ),
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=partial(get_replacement_qwen2_vision, modality="image"),
            ),
            PromptReplacement(
                modality="video",
                target=video_token,
                replacement=video_replacement_fn,
            ),
        ]

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_update: bool,
    ) -> tuple[list[int], BatchFeature, bool]:
        """
        Qwen2.5-Omni reimplements this function to handle text only.

        A text prompt with ``enable_hf_prompt_update`` set goes through the
        HF processor together with the multimodal items. Otherwise the prompt
        is tokenized on its own (text via the tokenizer, token ids via
        ``_apply_hf_processor_tokens_only``), the multimodal items are
        processed separately, and the returned flag is ``False``.
        """
        if isinstance(prompt, str):
            if enable_hf_prompt_update:
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                    tokenization_kwargs=tokenization_kwargs,
                )
            tokenizer = self.info.get_tokenizer()
            prompt_ids = encode_tokens(tokenizer, prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        mm_processed_data = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return prompt_ids, mm_processed_data, False

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.

        The multimodal items are processed with a dummy text prompt. With
        ``use_audio_in_video``, the audio count for that prompt is reduced by
        the number of videos, because those audio items are consumed by the
        video placeholders.
        """
        mm_counts = mm_items.get_all_counts()

        use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False)
        if use_audio_in_video and "video" in mm_counts:
            assert "audio" in mm_counts
            mm_counts["audio"] -= mm_counts["video"]

        _, mm_processed_data, _ = self._apply_hf_processor_text_mm(
            prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return mm_processed_data

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
        use_audio_in_video: bool = False,
    ) -> None:
        """Validate placeholders, discounting audio items embedded in videos.

        With ``use_audio_in_video``, a copy of the counts has the audio count
        reduced by the video count before the base validation runs. The
        caller's mapping is not modified.
        """
        if use_audio_in_video:
            mm_item_counts = copy(mm_item_counts)
            if "video" in mm_item_counts:
                assert "audio" in mm_item_counts
                mm_item_counts["audio"] -= mm_item_counts["video"]
        super()._validate_mm_placeholders(mm_placeholders, mm_item_counts)

_apply_hf_processor_main

_apply_hf_processor_main(
    prompt: Union[str, list[int]],
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Mapping[str, object],
    *,
    enable_hf_prompt_update: bool,
) -> tuple[list[int], BatchFeature, bool]

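Qwen2.5-Omni reimplements this function so that, unless a text prompt is processed with HF prompt updates enabled, the prompt is tokenized on its own and the multimodal data is processed separately.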

Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _apply_hf_processor_main(
    self,
    prompt: Union[str, list[int]],
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Mapping[str, object],
    *,
    enable_hf_prompt_update: bool,
) -> tuple[list[int], BatchFeature, bool]:
    """Produce prompt token ids plus processed multimodal data.

    A text prompt with ``enable_hf_prompt_update`` set is handed to the HF
    processor together with the multimodal items in one call. In every other
    case the prompt is tokenized on its own (text via the tokenizer, token
    ids via ``_apply_hf_processor_tokens_only``), the multimodal items are
    processed separately, and the returned flag is ``False``.
    """
    is_text_prompt = isinstance(prompt, str)

    if is_text_prompt and enable_hf_prompt_update:
        return self._apply_hf_processor_text_mm(
            prompt_text=prompt,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

    prompt_ids = (
        encode_tokens(self.info.get_tokenizer(), prompt)
        if is_text_prompt
        else self._apply_hf_processor_tokens_only(prompt)
    )

    processed_mm = self._apply_hf_processor_mm_only(
        mm_items=mm_items,
        hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        tokenization_kwargs=tokenization_kwargs,
    )
    return prompt_ids, processed_mm, False

_apply_hf_processor_mm_only

_apply_hf_processor_mm_only(
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Mapping[str, object],
) -> BatchFeature

Qwen2.5-Omni reimplements this function to handle use_audio_in_video.

Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _apply_hf_processor_mm_only(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Process only the multimodal items, using a dummy text prompt.

    Qwen2.5-Omni overrides this hook for ``use_audio_in_video``. In that mode
    the audio count used to build the dummy prompt is reduced by the number
    of videos, because those audio items are consumed by video placeholders.
    """
    item_counts = mm_items.get_all_counts()

    audio_rides_with_video = hf_processor_mm_kwargs.get("use_audio_in_video", False)
    if audio_rides_with_video and "video" in item_counts:
        assert "audio" in item_counts
        item_counts["audio"] -= item_counts["video"]

    dummy_prompt = self.dummy_inputs.get_dummy_text(item_counts)
    _, processed_mm, _ = self._apply_hf_processor_text_mm(
        prompt_text=dummy_prompt,
        mm_items=mm_items,
        hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        tokenization_kwargs=tokenization_kwargs,
    )
    return processed_mm

_call_hf_processor

_call_hf_processor(
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Call the HF processor and normalize its outputs for vLLM.

    - ``mm_data["audios"]`` is forwarded to the HF processor as ``audio``.
    - ``input_features`` becomes ``input_audio_features``. It is packed with
      ``feature_attention_mask`` when that mask is returned.
    - ``audio_feature_lengths`` is derived from the mask when missing.
    - ``video_second_per_grid`` is copied to ``second_per_grid_ts``.
    - ``use_audio_in_video`` is recorded as a tensor.
    """
    mm_data = dict(mm_data)
    audios = mm_data.pop("audios", [])

    # NOTE: WhisperFeatureExtractor cannot handle an empty list of audios,
    # so the "audio" key is only set when at least one audio is present.
    if audios:
        # The Qwen2.5-Omni HF processor expects the key "audio".
        mm_data["audio"] = audios
        mm_kwargs = dict(**mm_kwargs)

    hf_inputs = super()._call_hf_processor(
        prompt=prompt,
        mm_data=mm_data,
        mm_kwargs=mm_kwargs,
        tok_kwargs=tok_kwargs,
    )

    raw_features = hf_inputs.pop("input_features", None)
    attention_mask = hf_inputs.get("feature_attention_mask", None)

    if raw_features is not None and "input_audio_features" not in hf_inputs:
        if attention_mask is not None:
            # Keep only positions where the mask is set and pack all audios
            # along the last axis. Assumes features are
            # (num_audios, num_mel_bins, max_len) and the mask is
            # (num_audios, max_len) -- TODO confirm against HF output.
            per_frame = raw_features.permute(0, 2, 1)
            raw_features = per_frame[attention_mask.bool()].permute(1, 0)
        hf_inputs["input_audio_features"] = raw_features

    if attention_mask is not None and "audio_feature_lengths" not in hf_inputs:
        hf_inputs["audio_feature_lengths"] = attention_mask.sum(-1)

    seconds_per_grid = hf_inputs.get("video_second_per_grid", None)
    if seconds_per_grid is not None:
        hf_inputs["second_per_grid_ts"] = seconds_per_grid

    hf_inputs["use_audio_in_video"] = torch.tensor(
        mm_kwargs.get("use_audio_in_video", False)
    )
    return hf_inputs

_get_data_parser

_get_data_parser() -> MultiModalDataParser
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _get_data_parser(self) -> MultiModalDataParser:
    """Build the audio-aware data parser for this model.

    The target sampling rate comes from the HF feature extractor and the
    spatial merge size from the vision config.
    """
    sampling_rate = self.info.get_feature_extractor().sampling_rate
    merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
    return Qwen2_5OmniThinkerMultiModalDataParser(
        spatial_merge_size=merge_size,
        target_sr=sampling_rate,
    )

_get_mm_fields_config

_get_mm_fields_config(
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """Describe how each HF output field maps onto multimodal items.

    The layout comes from the shared Qwen2.5-Omni field factory, which is
    parameterized by the vision config's spatial merge size.
    ``hf_processor_mm_kwargs`` is not used.
    """
    merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
    fields_factory = create_qwen2_5_omni_thinker_field_factory(merge_size)
    return fields_factory(hf_inputs)

_get_prompt_updates

_get_prompt_updates(
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, Any],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, Any],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
    """Build the placeholder replacements for audio, image and video.

    An audio placeholder expands to one audio token per output feature.
    Image and video placeholders expand to
    ``prod(grid_thw) // merge_size**2`` tokens. With ``use_audio_in_video``,
    a video placeholder instead expands to interleaved video/audio chunks
    via ``omni_get_updates_use_audio_in_video``.

    NOTE: the replacement callables share the mutable counter
    ``audio_in_video_item_idx``, so their results depend on the order in
    which they are invoked.
    """
    processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
    tokenizer = self.info.get_tokenizer()
    image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
    vocab = tokenizer.get_vocab()

    audio_token = processor.audio_token
    image_token = processor.image_token
    video_token = processor.video_token
    audio_token_id = vocab[audio_token]
    image_token_id = vocab[image_token]
    video_token_id = vocab[video_token]

    # Audio token count per audio item, derived from the explicit feature
    # lengths if present, otherwise from the feature attention mask.
    out_mm_data = out_mm_kwargs.get_data()
    audio_feature_lengths = out_mm_data.get("audio_feature_lengths")
    feature_attention_mask = out_mm_data.get("feature_attention_mask")
    if audio_feature_lengths is None and feature_attention_mask is None:
        audio_output_lengths = []
    elif audio_feature_lengths is not None:
        _, audio_output_lens = _get_feat_extract_output_lengths(
            audio_feature_lengths
        )
        audio_output_lengths = audio_output_lens.tolist()
    elif feature_attention_mask is not None:
        assert isinstance(feature_attention_mask, torch.Tensor)
        _, audio_output_lens = _get_feat_extract_output_lengths(
            feature_attention_mask.sum(-1)
        )
        audio_output_lengths = audio_output_lens.tolist()

    # number of audios read from video.
    audio_in_video_item_idx = 0

    def get_replacement_qwen2_audio(item_idx: int):
        # Offset by the audios already consumed by videos; the counter is
        # only advanced by the use_audio_in_video replacement below.
        item_idx += audio_in_video_item_idx

        num_features = audio_output_lengths[item_idx]
        if num_features == 0:
            audios = mm_items.get_items("audio", AudioProcessorItems)
            audio = audios.get(item_idx)
            raise ValueError(
                f"The audio {audio} (len={len(audio)}) is too short "
                "to be represented inside the model"
            )

        return [audio_token_id] * num_features

    def get_replacement_qwen2_vision(item_idx: int, modality: str):
        grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx]
        assert isinstance(grid_thw, torch.Tensor)
        merge_length = image_processor.merge_size**2

        token_id = image_token_id if modality == "image" else video_token_id
        return [token_id] * (int(grid_thw.prod()) // merge_length)

    use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False)
    thinker_config = self.info.get_hf_config()

    def get_replacement_qwen2_use_audio_in_video(item_idx: int):
        nonlocal audio_in_video_item_idx

        # NOTE(review): if this runs exactly once per video in order, the
        # counter equals item_idx here and the index becomes 2 * item_idx.
        # Verify against how audio items are ordered, especially with more
        # than one video.
        audio_num_features = audio_output_lengths[
            audio_in_video_item_idx + item_idx
        ]
        video_grid_thw = out_mm_data["video_grid_thw"][item_idx]

        audio_in_video_item_idx += 1

        # NOTE(review): this reads hf_processor_mm_kwargs, whereas
        # _call_hf_processor stores second_per_grid_ts in the HF outputs.
        # Unless the caller passes it as a kwarg, 1.0 is used. Also, `if` on
        # a multi-element tensor would raise -- confirm intent.
        second_per_grid_ts = hf_processor_mm_kwargs.get("second_per_grid_ts", None)
        if second_per_grid_ts:
            video_second_per_grid_t = second_per_grid_ts[item_idx]
        else:
            video_second_per_grid_t = 1.0

        return self.omni_get_updates_use_audio_in_video(
            thinker_config=thinker_config,
            audio_len=audio_num_features,
            video_grid_thw=video_grid_thw,
            video_second_per_grid_t=video_second_per_grid_t,
        )

    video_replacement_fn = (
        get_replacement_qwen2_use_audio_in_video
        if use_audio_in_video
        else partial(get_replacement_qwen2_vision, modality="video")
    )

    return [
        PromptReplacement(
            modality="audio",
            target=audio_token,
            replacement=get_replacement_qwen2_audio,
        ),
        PromptReplacement(
            modality="image",
            target=image_token,
            replacement=partial(get_replacement_qwen2_vision, modality="image"),
        ),
        PromptReplacement(
            modality="video",
            target=video_token,
            replacement=video_replacement_fn,
        ),
    ]

_maybe_apply_prompt_updates

_maybe_apply_prompt_updates(
    mm_items: MultiModalDataItems,
    prompt_ids: list[int],
    mm_kwargs: MultiModalKwargsItems,
    mm_prompt_updates: MultiModalPromptUpdates,
    is_update_applied: bool,
) -> tuple[
    list[int], Mapping[str, list[PlaceholderFeaturesInfo]]
]

Qwen2.5-Omni reimplements this function to handle use_audio_in_video.

Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _maybe_apply_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    prompt_ids: list[int],
    mm_kwargs: MultiModalKwargsItems,
    mm_prompt_updates: MultiModalPromptUpdates,
    is_update_applied: bool,
) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
    """
    Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.

    The `use_audio_in_video` flag is enabled only when at least one
    non-``None`` video item exists and every such item has a truthy
    ``use_audio_in_video`` field. The flag is forwarded to
    ``_validate_mm_placeholders``.

    Args:
        mm_items: Multimodal data items; only their per-modality counts
            are used here.
        prompt_ids: Token ids of the prompt.
        mm_kwargs: Processed multimodal kwargs; video items are read for
            their ``use_audio_in_video`` field.
        mm_prompt_updates: Prompt updates to locate or apply.
        is_update_applied: If True, ``prompt_ids`` is taken to already
            contain the updates and placeholders are only located;
            otherwise the updates are applied to ``prompt_ids``.

    Returns:
        The (possibly updated) prompt ids and the placeholders per modality.
    """
    mm_item_counts = mm_items.get_all_counts()
    self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
    self._validate_mm_updates(mm_prompt_updates, mm_item_counts)

    # Only real (non-None) video items take part in the flag check; with no
    # such items the flag stays off.
    present_videos = (
        [entry for entry in mm_kwargs["video"] if entry is not None]
        if "video" in mm_kwargs
        else []
    )
    use_audio_in_video = bool(present_videos) and all(
        entry["use_audio_in_video"].data for entry in present_videos
    )

    if is_update_applied:
        mm_placeholders = self._find_mm_placeholders(prompt_ids, mm_prompt_updates)
    else:
        prompt_ids, mm_placeholders = self._apply_prompt_updates(
            prompt_ids, mm_prompt_updates
        )

    # Both paths validate identically once placeholders are known.
    self._validate_mm_placeholders(
        mm_placeholders,
        mm_item_counts,
        use_audio_in_video=use_audio_in_video,
    )
    return prompt_ids, mm_placeholders

_validate_mm_placeholders

_validate_mm_placeholders(
    mm_placeholders: Mapping[
        str, list[PlaceholderFeaturesInfo]
    ],
    mm_item_counts: Mapping[str, int],
    use_audio_in_video: bool = False,
) -> None
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def _validate_mm_placeholders(
    self,
    mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
    mm_item_counts: Mapping[str, int],
    use_audio_in_video: bool = False,
) -> None:
    """Validate placeholders, discounting audio items embedded in videos.

    When ``use_audio_in_video`` is set, each video's audio is emitted inside
    the video replacement rather than as a standalone audio placeholder. One
    audio item per video is therefore subtracted from the expected audio
    count before delegating to the base implementation.

    Args:
        mm_placeholders: Placeholders found per modality.
        mm_item_counts: Number of items per modality. It is never mutated.
        use_audio_in_video: Whether video items carry their own audio.

    Raises:
        ValueError: If ``use_audio_in_video`` is set and videos are present
            but no audio items were supplied for them.
    """
    if use_audio_in_video and "video" in mm_item_counts:
        # Explicit check instead of `assert`: this depends on user input and
        # must not be stripped under `python -O` (which would surface as an
        # opaque KeyError below).
        if "audio" not in mm_item_counts:
            raise ValueError(
                "`use_audio_in_video` is enabled, but no audio items were "
                "provided for the video inputs."
            )
        # Fresh dict: works for any Mapping and leaves the caller's counts
        # untouched.
        mm_item_counts = dict(mm_item_counts)
        mm_item_counts["audio"] -= mm_item_counts["video"]
    super()._validate_mm_placeholders(mm_placeholders, mm_item_counts)

omni_get_updates_use_audio_in_video classmethod

omni_get_updates_use_audio_in_video(
    thinker_config: PretrainedConfig,
    audio_len: int,
    video_grid_thw: Union[list[int], Tensor],
    video_second_per_grid_t: float,
) -> list[int]

Get video prompt updates when use_audio_in_video is True.

In this case, audio and vision update ids will be split into chunks and interleaved (details in _omni_get_input_positions_tensor).

<|video_bos|><|VIDEO|><|video_eos|> => <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>

Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
@classmethod
def omni_get_updates_use_audio_in_video(
    cls,
    thinker_config: PretrainedConfig,
    audio_len: int,
    video_grid_thw: Union[list[int], torch.Tensor],
    video_second_per_grid_t: float,
) -> list[int]:
    """Build the video replacement token ids when `use_audio_in_video` is True.

    Vision and audio ids are split into chunks and interleaved (the matching
    position ids are described in `_omni_get_input_positions_tensor`):

    <|video_bos|><|VIDEO|><|video_eos|> =>
    <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>

    The returned list spans only `<|audio_bos|>` ... `<|audio_eos|>`.
    Each chunk contributes the merged vision tokens of its temporal
    positions, followed by up to `int(tokens_per_second * seconds_per_chunk)`
    audio tokens. Audio left over after the last chunk is appended just
    before `<|audio_eos|>`. Exactly `audio_len` audio tokens are emitted in
    total.

    Args:
        thinker_config: Supplies the token ids and `seconds_per_chunk`, plus
            the vision `spatial_merge_size` and `tokens_per_second` (the
            latter defaults to 25 when absent).
        audio_len: Number of audio tokens to place.
        video_grid_thw: (t, h, w) grid of the video.
        video_second_per_grid_t: Presumably the seconds covered by one
            temporal grid step (callers default it to 1.0).

    Returns:
        Token ids to substitute for the video placeholder.
    """
    vision_config = thinker_config.vision_config

    audio_token_id = thinker_config.audio_token_index
    video_token_id = thinker_config.video_token_index
    audio_start_token_id = thinker_config.audio_start_token_id
    audio_end_token_id = thinker_config.audio_end_token_id
    seconds_per_chunk = thinker_config.seconds_per_chunk
    spatial_merge_size = vision_config.spatial_merge_size
    tokens_per_second = getattr(vision_config, "tokens_per_second", 25)

    grid_t, grid_h, grid_w = (
        video_grid_thw[0],
        video_grid_thw[1],
        video_grid_thw[2],
    )
    audio_per_chunk = int(tokens_per_second * seconds_per_chunk)
    # Temporal position of every grid step, expressed in audio-token units.
    frame_positions = (
        torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second
    ).long()
    chunks = split_list_into_ranges(frame_positions, audio_per_chunk)

    updates = [audio_start_token_id]
    audio_emitted = 0
    for chunk in chunks:
        # Vision tokens for this chunk's grid steps, after spatial merging.
        num_vision = len(chunk) * grid_h * grid_w // (spatial_merge_size**2)
        updates.extend([video_token_id] * num_vision)

        # Then at most one chunk's worth of the still-unplaced audio tokens.
        num_audio = min(audio_per_chunk, audio_len - audio_emitted)
        updates.extend(num_audio * [audio_token_id])
        audio_emitted += num_audio

    # Audio that outlasts the video's chunks goes at the end.
    if audio_emitted < audio_len:
        updates.extend((audio_len - audio_emitted) * [audio_token_id])
    updates.append(audio_end_token_id)

    return updates

Qwen2_5OmniThinkerProcessingInfo

Bases: Qwen2AudioProcessingInfo, Qwen2_5_VLProcessingInfo

Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
class Qwen2_5OmniThinkerProcessingInfo(
    Qwen2AudioProcessingInfo, Qwen2_5_VLProcessingInfo
):
    """Processing info for the Qwen2.5-Omni thinker.

    Mixes in ``Qwen2AudioProcessingInfo`` and ``Qwen2_5_VLProcessingInfo``.
    It resolves the HF config and processor to their Qwen2.5-Omni variants.
    """

    def get_hf_config(self):
        """Return the ``thinker_config`` of the Qwen2.5-Omni HF config."""
        omni_config = self.ctx.get_hf_config(Qwen2_5OmniConfig)
        return omni_config.thinker_config

    def get_hf_processor(self, **kwargs: object) -> Qwen2_5OmniProcessor:
        """Return the HF Qwen2.5-Omni processor (``use_fast`` defaults to True)."""
        processor_kwargs = {"use_fast": kwargs.pop("use_fast", True), **kwargs}
        return self.ctx.get_hf_processor(Qwen2_5OmniProcessor, **processor_kwargs)

    def get_feature_extractor(self, **kwargs: object):
        """Return the processor's feature extractor, checked to be Whisper's."""
        extractor = self.get_hf_processor(**kwargs).feature_extractor  # type: ignore
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Audio, image and video are supported, each with a ``None`` limit."""
        return dict.fromkeys(("audio", "image", "video"))

get_feature_extractor

get_feature_extractor(**kwargs: object)
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def get_feature_extractor(self, **kwargs: object):
    """Return the processor's feature extractor, checked to be Whisper's."""
    extractor = self.get_hf_processor(**kwargs).feature_extractor  # type: ignore
    assert isinstance(extractor, WhisperFeatureExtractor)
    return extractor

get_hf_config

get_hf_config()
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def get_hf_config(self):
    """Return the ``thinker_config`` of the Qwen2.5-Omni HF config."""
    omni_config = self.ctx.get_hf_config(Qwen2_5OmniConfig)
    return omni_config.thinker_config

get_hf_processor

get_hf_processor(**kwargs: object) -> Qwen2_5OmniProcessor
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def get_hf_processor(self, **kwargs: object) -> Qwen2_5OmniProcessor:
    """Return the HF Qwen2.5-Omni processor (``use_fast`` defaults to True)."""
    processor_kwargs = {"use_fast": kwargs.pop("use_fast", True), **kwargs}
    return self.ctx.get_hf_processor(Qwen2_5OmniProcessor, **processor_kwargs)

get_supported_mm_limits

get_supported_mm_limits() -> Mapping[str, Optional[int]]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Audio, image and video are supported, each with a ``None`` limit."""
    return dict.fromkeys(("audio", "image", "video"))

create_qwen2_5_omni_thinker_field_factory

create_qwen2_5_omni_thinker_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, Tensor]],
    Mapping[str, MultiModalFieldConfig],
]
Source code in vllm/model_executor/models/qwen2_5_omni_thinker.py
def create_qwen2_5_omni_thinker_field_factory(
    spatial_merge_size: int,
) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, MultiModalFieldConfig]]:
    """Create the multimodal field-config callback for Qwen2.5-Omni thinker.

    Args:
        spatial_merge_size: Spatial merge factor of the vision encoder. The
            vision embedding count of an item is its patch count
            floor-divided by this value twice.

    Returns:
        A callable that takes the HF processor outputs and returns a mapping
        from each input key to its ``MultiModalFieldConfig``. Absent
        ``audio_feature_lengths`` / ``*_grid_thw`` tensors are replaced by
        empty tensors.
    """

    def _grid_sizes(
        hf_inputs: Mapping[str, torch.Tensor], grid_key: str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Per item: patch count (product of its t, h, w grid) and the number
        # of embeddings left after spatial merging.
        grid_thw = hf_inputs.get(grid_key, torch.empty((0, 3)))
        patch_sizes = grid_thw.prod(-1)
        embed_sizes = patch_sizes // spatial_merge_size // spatial_merge_size
        return patch_sizes, embed_sizes

    def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        audio_lengths = hf_inputs.get("audio_feature_lengths", torch.empty((0,)))
        image_patch_sizes, image_embed_sizes = _grid_sizes(
            hf_inputs, "image_grid_thw"
        )
        video_patch_sizes, video_embed_sizes = _grid_sizes(
            hf_inputs, "video_grid_thw"
        )

        return {
            # Audio features are sized per item along dim 1.
            "input_audio_features": MultiModalFieldConfig.flat_from_sizes(
                "audio", audio_lengths, dim=1
            ),
            "feature_attention_mask": MultiModalFieldConfig.batched("audio"),
            "audio_feature_lengths": MultiModalFieldConfig.batched("audio"),
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", image_patch_sizes
            ),
            "image_embeds": MultiModalFieldConfig.flat_from_sizes(
                "image", image_embed_sizes
            ),
            "image_grid_thw": MultiModalFieldConfig.batched("image"),
            "pixel_values_videos": MultiModalFieldConfig.flat_from_sizes(
                "video", video_patch_sizes
            ),
            "video_embeds": MultiModalFieldConfig.flat_from_sizes(
                "video", video_embed_sizes
            ),
            "video_grid_thw": MultiModalFieldConfig.batched("video"),
            "second_per_grid_ts": MultiModalFieldConfig.batched("video"),
            # Presumably one flag value is shared across all video items.
            "use_audio_in_video": MultiModalFieldConfig.shared(
                "video", len(video_patch_sizes)
            ),
        }

    return _qwen2_5_omni_thinker_field_config