mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00

* VLMs can work with embeds now * update more models * fix tests * fix copies * fixup * fix * style * unskip tests * fix copies * fix tests * style * omni modality models * qwen models had extra indentation * fix some other tests * fix copies * fix test last time * unrelated changes revert * we can't rely only on embeds * delete file * de-flake mistral3 * fix qwen models * fix style * fix tests * fix copies * deflake the test * modular reverted by fixes, fix again * flaky test, overwritten * fix copies * style
525 lines
22 KiB
Python
525 lines
22 KiB
Python
# coding=utf-8
|
|
# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PyTorch Llava model."""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
|
|
from ...activations import ACT2FN
|
|
from ...generation import GenerationMixin
|
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
|
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
|
from ...modeling_utils import PreTrainedModel
|
|
from ...processing_utils import Unpack
|
|
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
|
|
from ..auto import AutoModel
|
|
from .configuration_llava import LlavaConfig
|
|
|
|
|
|
# Module-level logger named after this module (obtained via the library's `logging` utilities).
logger = logging.get_logger(__name__)
|
|
|
|
|
|
# Output container returned by `LlavaModel.forward`: the language-model outputs (inherited fields) plus the
# projected image features in `image_hidden_states`.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Llava outputs, with hidden states and attentions.
    """
)
class LlavaModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # `LlavaModel.forward` fills this only when `pixel_values` is given, with the per-image projected features
    # concatenated along dim 0 (`torch.cat(image_features, dim=0)`); otherwise it stays `None`.
    # NOTE(review): that concatenation looks like `(total_image_tokens, hidden_size)` rather than the 4-D shape
    # stated in the docstring above — confirm and update the docstring.
    image_hidden_states: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
# Output container returned by `LlavaForConditionalGeneration.forward`: optional LM loss, logits, and the
# pass-through fields of the underlying `LlavaModel` output.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Llava causal language model (or autoregressive) outputs.
    """
)
class LlavaCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # Field order matters: `ModelOutput` converts to a tuple in declaration order (skipping `None` entries).
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    # NOTE(review): annotated as a list while the docstring describes nested tuples; whatever the language model
    # returns is passed through unchanged.
    past_key_values: Optional[list[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    # Copied from `LlavaModel`'s `image_hidden_states` (per-image features concatenated along dim 0). The 4-D shape
    # in the docstring may be outdated — TODO confirm.
    image_hidden_states: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision-tower features into the text model's hidden space.

    The input width is the vision hidden size multiplied by the number of selected vision feature
    layers, because features from several layers are concatenated on the last axis before projection.
    """

    def __init__(self, config: LlavaConfig):
        super().__init__()
        feature_layers = config.vision_feature_layer
        # A single int selects one layer; a list selects several whose features get concatenated.
        layer_count = 1 if isinstance(feature_layers, int) else len(feature_layers)
        text_width = config.text_config.hidden_size
        with_bias = config.multimodal_projector_bias

        self.linear_1 = nn.Linear(config.vision_config.hidden_size * layer_count, text_width, bias=with_bias)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(text_width, text_width, bias=with_bias)

    def forward(self, image_features):
        """Apply linear -> activation -> linear to `image_features` and return the projection."""
        projected = self.linear_1(image_features)
        activated = self.act(projected)
        return self.linear_2(activated)
|
|
|
|
|
|
@auto_docstring
class LlavaPreTrainedModel(PreTrainedModel):
    config_class = LlavaConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize `nn.Linear` and `nn.LayerNorm` parameters of `module` in place.

        Linear weights are drawn from N(0, std) and their biases are zeroed. LayerNorm weights are set to 1
        and biases to 0. Parameters a module was built without (`None`) are skipped.
        """
        # important: this ported version of Llava isn't meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
        # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
        # Resolve the fallback lazily, so the text config is consulted only when the top-level config has no value.
        std = getattr(self.config, "initializer_range", None)
        if std is None:
            std = self.config.get_text_config().initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm may be created with `elementwise_affine=False` (no weight/bias) or `bias=False` (no bias).
            # This hook runs over every submodule, including the vision tower, so guard both parameters.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    The Llava model which consists of a vision backbone and a language model, without a language modeling head.
    """
)
class LlavaModel(LlavaPreTrainedModel):
    # Renames legacy checkpoint keys `language_model.model.*` to this module's `language_model.*`.
    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}

    def __init__(self, config: LlavaConfig):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config.vision_config)

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = AutoModel.from_config(config.text_config)
        self.post_init()

    def get_input_embeddings(self):
        # Token embeddings are owned by the wrapped language model.
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        **kwargs,
    ):
        """
        Obtains image last hidden states from the vision tower and applies the multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.
            vision_feature_layer (`Union[int, list[int]]`, *optional*):
                The index of the layer to select the vision feature. If multiple indices are provided,
                the vision feature of the corresponding indices will be concatenated to form the
                vision features. Defaults to `config.vision_feature_layer`.
            vision_feature_select_strategy (`str`, *optional*):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` (drop the first token of each selected hidden state) or `"full"`.
                Defaults to `config.vision_feature_select_strategy`.
            **kwargs:
                Extra keyword arguments forwarded to the vision tower; entries whose value is `None` are dropped.
                When `image_sizes` is present, the projected features are split per image into chunks of
                `(height // patch_size) * (width // patch_size)` vectors.

        Returns:
            image_features (`list[torch.Tensor]` or `tuple[torch.Tensor]`): One projected feature tensor per image
            (a tuple from `torch.split` when `image_sizes` is given, otherwise a list).

        Raises:
            ValueError: If the resolved `vision_feature_select_strategy` is neither `"default"` nor `"full"`.
        """
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if vision_feature_select_strategy not in ["default", "full"]:
            # Report the strategy that actually failed validation (a caller override or the config value).
            raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}")

        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)

        # If we have one vision feature layer, return the corresponding hidden states,
        # otherwise, select the hidden states of each feature layer and concatenate them
        if isinstance(vision_feature_layer, int):
            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
            if vision_feature_select_strategy == "default":
                selected_image_feature = selected_image_feature[:, 1:]
        else:
            hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
            # For default; crop CLS from each hidden state in the hidden state pool
            if vision_feature_select_strategy == "default":
                hs_pool = [hs[:, 1:] for hs in hs_pool]
            selected_image_feature = torch.cat(hs_pool, dim=-1)

        image_features = self.multi_modal_projector(selected_image_feature)

        if "image_sizes" in kwargs:
            split_sizes = [
                (height // self.vision_tower.patch_size) * (width // self.vision_tower.patch_size)
                for height, width in kwargs["image_sizes"]
            ]
            image_features = torch.split(image_features.squeeze(0), split_sizes)
        else:
            image_features = list(image_features)
        return image_features

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        image_sizes: torch.Tensor = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, LlavaModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_features = self.get_image_features(
                pixel_values=pixel_values,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
                image_sizes=image_sizes,
            )
            # Flatten the per-image features into one tensor so they can be scattered in token order.
            image_features = torch.cat(image_features, dim=0)

            if input_ids is None:
                # Without token ids, locate image placeholders by matching the embedding of `image_token_id`.
                special_image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                special_image_mask = special_image_mask.all(-1)
            else:
                special_image_mask = input_ids == self.config.image_token_id

            n_image_tokens = special_image_mask.sum()
            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                # Count feature *vectors* (all dims except the hidden one), not scalar elements, so the message
                # compares like with like against the number of placeholder tokens.
                n_image_features = image_features.shape[:-1].numel()
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        return LlavaModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )
|
|
|
|
|
|
# Typed `**kwargs` accepted by `LlavaForConditionalGeneration.forward`: the union of flash-attention kwargs and
# loss kwargs. `forward` passes them on to the base model and to `self.loss_function`.
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    The LLAVA model which consists of a vision backbone and a language model.
    """
)
class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
    # Maps legacy flat checkpoint keys onto the nested `model.*` / `lm_head` layout used here.
    _checkpoint_conversion_mapping = {
        "^language_model.model": "model.language_model",
        "^vision_tower": "model.vision_tower",
        "^multi_modal_projector": "model.multi_modal_projector",
        "^language_model.lm_head": "lm_head",
    }
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: LlavaConfig):
        super().__init__(config)
        self.model = LlavaModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.set_decoder(decoder)

    def get_decoder(self):
        # Call the getter: returning `self.model.get_decoder` would hand callers a bound method, not the module.
        return self.model.get_decoder()

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        **kwargs,
    ):
        """Delegate to `LlavaModel.get_image_features` and return its per-image projected features."""
        return self.model.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            **kwargs,
        )

    # Make modules available through the conditional class for BC
    @property
    def language_model(self):
        return self.model.language_model

    @property
    def vision_tower(self):
        return self.model.vision_tower

    @property
    def multi_modal_projector(self):
        return self.model.multi_modal_projector

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[tuple, LlavaCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

        >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            image_sizes=image_sizes,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        if cache_position[0] == 0:
            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
            # Otherwise we need pixel values to be passed to model
            model_inputs["pixel_values"] = pixel_values

        return model_inputs
|
|
|
|
|
|
# Public API of this module: the conditional-generation head, the shared pretrained base, and the bare model.
__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel", "LlavaModel"]
|