Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
[bug] fix llava processor to calculate unpadding size correctly (#37988)

* fix llava processor to calculate unpad size correctly
* repo consistency
* Revert "repo consistency" & "setUp in llava family" (this reverts commit 26a50af8db)
* add edge case test for padding & unpadding
* compute unpadding size from original size
* make test config explicit
* Revert "compute unpadding size from original size" (this reverts commit 752cd27ad9)
* Revert "add edge case test for padding & unpadding" (this reverts commit ccbd094d69)
* revert unpad logic
* remove irrelevant tests
* model test
* remove processor from model test

Co-authored-by: jaycha <jaycha@ncsoft.com>
This commit is contained in:
parent 67b3d45eb6
commit a5cc7a67d7
@@ -20,7 +20,6 @@ from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
-import torch.utils.checkpoint
 from torch import nn
 
 from ...activations import ACT2FN
@@ -133,14 +132,14 @@ def unpad_image(tensor, original_size):
 
     if original_aspect_ratio > current_aspect_ratio:
         scale_factor = current_width / original_width
-        new_height = min(math.ceil(original_height * scale_factor), current_height)
-        padding, r = divmod(current_height - new_height, 2)
-        unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
+        new_height = int(round(original_height * scale_factor, 7))
+        padding = (current_height - new_height) // 2
+        unpadded_tensor = tensor[:, padding : current_height - padding, :]
     else:
         scale_factor = current_height / original_height
-        new_width = min(math.ceil(original_width * scale_factor), current_width)
-        padding, r = divmod(current_width - new_width, 2)
-        unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
+        new_width = int(round(original_width * scale_factor, 7))
+        padding = (current_width - new_width) // 2
+        unpadded_tensor = tensor[:, :, padding : current_width - padding]
 
     return unpadded_tensor
 
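The two slicing schemes above disagree by one row (or column) whenever the amount of padding to remove is odd, which is exactly the off-by-one that let the processor's token count drift from the model's feature count. A standalone comparison of both schemes on such a case; the numbers are illustrative, not taken from the repo:

    # Compare the reverted divmod slicing with the restored floor-division
    # slicing when current_height - new_height is odd.
    import torch

    tensor = torch.randn(3, 16, 24)   # (channels, current_height, current_width)
    current_height = tensor.shape[1]
    new_height = 13                   # as if computed from the original aspect ratio

    # reverted divmod scheme: trims the odd leftover row, keeps exactly new_height
    padding, r = divmod(current_height - new_height, 2)
    trimmed = tensor[:, padding : current_height - (padding + r), :]

    # restored scheme: symmetric floor padding, keeps the extra row when odd
    padding = (current_height - new_height) // 2
    kept = tensor[:, padding : current_height - padding, :]

    print(trimmed.shape[1], kept.shape[1])  # 13 vs 14 -- a one-row disagreement

The restored `// 2` convention keeps the extra row, and the processor-side token counting in this PR is brought back in line with that convention.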
@@ -304,14 +304,14 @@ def unpad_image(tensor, original_size):
 
     if original_aspect_ratio > current_aspect_ratio:
         scale_factor = current_width / original_width
-        new_height = min(math.ceil(original_height * scale_factor), current_height)
-        padding, r = divmod(current_height - new_height, 2)
-        unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
+        new_height = int(round(original_height * scale_factor, 7))
+        padding = (current_height - new_height) // 2
+        unpadded_tensor = tensor[:, padding : current_height - padding, :]
     else:
         scale_factor = current_height / original_height
-        new_width = min(math.ceil(original_width * scale_factor), current_width)
-        padding, r = divmod(current_width - new_width, 2)
-        unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
+        new_width = int(round(original_width * scale_factor, 7))
+        padding = (current_width - new_width) // 2
+        unpadded_tensor = tensor[:, :, padding : current_width - padding]
 
     return unpadded_tensor
 
@@ -18,7 +18,6 @@ from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 
 from transformers.models.llava_next.modeling_llava_next import (
@@ -644,7 +644,7 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
                 image,
                 image_grid_pinpoints,
                 size=size_tuple,
-                patch_size=size["height"],
+                patch_size=size_tuple[0],
                 resample=resample,
                 data_format=input_data_format,
                 input_data_format=input_data_format,
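Context for this one-line fix: a processor `size` config may carry only `{"shortest_edge": N}` rather than `{"height": ..., "width": ...}`, in which case `size["height"]` raises a KeyError, while `size_tuple` has already been normalized from either form. A minimal sketch; the `size_tuple` derivation below is an assumption modeled on the surrounding image-processor code, not copied from it:

    # Why size_tuple[0] is the safer lookup for both valid "size" shapes.
    size = {"shortest_edge": 384}  # valid config with no "height"/"width" keys

    size_tuple = (
        (size["height"], size["width"])
        if "height" in size and "width" in size
        else (size["shortest_edge"], size["shortest_edge"])
    )

    patch_size = size_tuple[0]      # works for both dict shapes
    # patch_size = size["height"]   # would raise KeyError for this config
    print(patch_size)               # 384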
@@ -284,14 +284,14 @@ def unpad_image(tensor, original_size):
 
     if original_aspect_ratio > current_aspect_ratio:
         scale_factor = current_width / original_width
-        new_height = min(math.ceil(original_height * scale_factor), current_height)
-        padding, r = divmod(current_height - new_height, 2)
-        unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
+        new_height = int(round(original_height * scale_factor, 7))
+        padding = (current_height - new_height) // 2
+        unpadded_tensor = tensor[:, padding : current_height - padding, :]
     else:
         scale_factor = current_height / original_height
-        new_width = min(math.ceil(original_width * scale_factor), current_width)
-        padding, r = divmod(current_width - new_width, 2)
-        unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
+        new_width = int(round(original_width * scale_factor, 7))
+        padding = (current_width - new_width) // 2
+        unpadded_tensor = tensor[:, :, padding : current_width - padding]
 
     return unpadded_tensor
 
@@ -50,7 +50,7 @@ from ...test_modeling_common import (
 if is_torch_available():
     import torch
 
-    from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches, unpad_image
+    from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
 
 
 if is_vision_available():
@@ -298,18 +298,27 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
         image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
         _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
 
-    def test_unpad_image(self):
-        original_size = (400, 400)
-
-        # Test case width is padded
-        pixel_values = floats_tensor([3, 400, 601])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
-        self.assertEqual(unpadded_tensor.shape[1:], original_size)
-
-        # Test case height is padded
-        pixel_values = floats_tensor([3, 503, 400])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
-        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+    def test_odd_sized_image(self):
+        # prepare model configuration
+        config = self.model_tester.get_config()
+
+        # prepare input
+        num_image_tokens = 24
+        pixel_values = floats_tensor([1, 5, 3, config.vision_config.image_size, config.vision_config.image_size])
+        input_ids = ids_tensor([1, 64], config.text_config.vocab_size - 2) + 2
+        input_ids[:, :num_image_tokens] = config.image_token_index
+        attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
+        inputs_dict = {
+            "pixel_values": pixel_values,
+            "image_sizes": torch.tensor([[13, 16]]),  # odd-sized image
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+
+        # forward with odd-sized image input
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model(**inputs_dict)
 
     @parameterized.expand(
         [
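The replacement test drives an odd `image_sizes` value (`[[13, 16]]`) through a full forward pass instead of unit-testing `unpad_image` directly, because the real failure mode was a shape mismatch when scattering image features into the token sequence. A hedged sketch of that invariant; all names and sizes below are illustrative, not taken from the test:

    # The processor must reserve exactly as many placeholder tokens as the
    # vision tower emits features, or the scatter into the embeddings fails.
    import torch

    num_placeholders = 24                 # image tokens reserved in input_ids
    image_features = torch.randn(23, 8)   # one feature short -> the pre-fix mismatch

    inputs_embeds = torch.zeros(64, 8)
    mask = torch.zeros(64, dtype=torch.bool)
    mask[:num_placeholders] = True

    try:
        inputs_embeds[mask] = image_features  # same class of failure as the bug
    except RuntimeError as err:
        print(f"shape mismatch: {err}")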
@@ -11,13 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import json
+import shutil
 import tempfile
 import unittest
 
 import torch
 
-from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextProcessor
+from transformers import LlamaTokenizerFast, LlavaNextProcessor
 from transformers.testing_utils import (
     require_vision,
 )
@@ -52,6 +54,10 @@ class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     def get_image_processor(self, **kwargs):
         return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
 
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tmpdirname, ignore_errors=True)
+
     @staticmethod
     def prepare_processor_dict():
         return {
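The added `tearDownClass` closes a temp-dir leak: `tmpdirname` is created once per test class and was previously never removed. A minimal sketch of the lifecycle, with a `setUpClass` body assumed to mirror what these processor test classes do:

    import os
    import shutil
    import tempfile
    import unittest


    class ExampleProcessorTest(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            cls.tmpdirname = tempfile.mkdtemp()  # shared by every test in the class

        @classmethod
        def tearDownClass(cls):
            # without this, each run leaves a stray directory behind
            shutil.rmtree(cls.tmpdirname, ignore_errors=True)

        def test_tmpdir_exists(self):
            self.assertTrue(os.path.isdir(self.tmpdirname))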
@@ -73,13 +79,16 @@ class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
 
     def test_image_token_filling(self):
-        processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
         processor.patch_size = 14
         processor.vision_feature_select_strategy = "default"
+        processor.image_processor.crop_size = {"height": 336, "width": 336}
+        processor.image_processor.size = {"shortest_edge": 336}
+        processor.image_processor.image_grid_pinpoints = [[672, 336]]
         # Important to check with non square image
-        image = torch.randint(0, 2, (3, 500, 316))
-        expected_image_tokens = 1526
-        image_token_index = 32000
+        image = torch.randint(0, 2, (3, 503, 316))
+        expected_image_tokens = 1525
+        image_token_index = processor.image_token_id
 
         messages = [
             {
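The updated expectation (`1525` tokens for a 503x316 image with `image_grid_pinpoints=[[672, 336]]`) comes from the processor's feature counting, which this PR aligns with `unpad_image`'s rounding and floor-division convention. The processor change itself is not among the hunks shown here; the sketch below is an assumed shape of that counting, mirroring the modeling code above, not a copy of the repo's implementation:

    # Assumed processor-side counting of unpadded patch features, using the
    # same int(round(..., 7)) and floor-division convention as unpad_image.
    def count_unpadded_features(height, width, current_height, current_width):
        original_aspect_ratio = width / height
        current_aspect_ratio = current_width / current_height
        if original_aspect_ratio > current_aspect_ratio:
            new_height = int(round(height * (current_width / width), 7))
            padding = (current_height - new_height) // 2
            current_height -= 2 * padding
        else:
            new_width = int(round(width * (current_height / height), 7))
            padding = (current_width - new_width) // 2
            current_width -= 2 * padding
        unpadded_features = current_height * current_width
        newline_features = current_height  # one newline token per feature row
        return unpadded_features, newline_features

    # 503x316 image on an illustrative 48x24 patch-feature grid
    print(count_unpadded_features(503, 316, 48, 24))  # (912, 38)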
@@ -49,8 +49,6 @@ from ...test_modeling_common import (
 if is_torch_available():
     import torch
 
-    from transformers.models.llava_next_video.modeling_llava_next_video import unpad_image
-
 
 if is_vision_available():
     from PIL import Image
@@ -314,18 +312,27 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
         image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
         _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
 
-    def test_unpad_image(self):
-        original_size = (400, 400)
-
-        # Test case width is padded
-        pixel_values = floats_tensor([3, 400, 601])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
-        self.assertEqual(unpadded_tensor.shape[1:], original_size)
-
-        # Test case height is padded
-        pixel_values = floats_tensor([3, 503, 400])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
-        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+    def test_odd_sized_image(self):
+        # prepare model configuration
+        config = self.model_tester.get_config()
+
+        # prepare input
+        num_image_tokens = 24
+        pixel_values = floats_tensor([1, 5, 3, config.vision_config.image_size, config.vision_config.image_size])
+        input_ids = ids_tensor([1, 64], config.text_config.vocab_size - 2) + 2
+        input_ids[:, :num_image_tokens] = config.image_token_index
+        attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
+        inputs_dict = {
+            "pixel_values": pixel_values,
+            "image_sizes": torch.tensor([[13, 16]]),  # odd-sized image
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+
+        # forward with odd-sized image input
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model(**inputs_dict)
 
     @parameterized.expand(
         [
@@ -17,6 +17,8 @@ import shutil
 import tempfile
 import unittest
 
+import torch
+
 from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextVideoProcessor
 from transformers.testing_utils import require_vision
 from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
@@ -63,6 +65,10 @@ class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     def get_video_processor(self, **kwargs):
         return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
 
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tmpdirname, ignore_errors=True)
+
     @classmethod
     def prepare_processor_dict(cls):
         return {
@@ -84,6 +90,31 @@ class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         processor_dict = self.prepare_processor_dict()
         self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
 
-    @classmethod
-    def tearDownClass(cls):
-        shutil.rmtree(cls.tmpdirname, ignore_errors=True)
+    def test_image_token_filling(self):
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
+        processor.patch_size = 14
+        processor.vision_feature_select_strategy = "default"
+        processor.image_processor.crop_size = {"height": 336, "width": 336}
+        processor.image_processor.size = {"shortest_edge": 336}
+        processor.image_processor.image_grid_pinpoints = [[672, 336]]
+        # Important to check with non square image
+        image = torch.randint(0, 2, (3, 503, 316))
+        expected_image_tokens = 1525
+        image_token_index = processor.image_token_id
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": "What is shown in this image?"},
+                ],
+            },
+        ]
+        inputs = processor(
+            text=[processor.apply_chat_template(messages)],
+            images=[image],
+            return_tensors="pt",
+        )
+        image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
+        self.assertEqual(expected_image_tokens, image_tokens)
@@ -49,8 +49,6 @@ from ...test_modeling_common import (
 if is_torch_available():
     import torch
 
-    from transformers.models.llava_onevision.modeling_llava_onevision import unpad_image
-
 
 if is_vision_available():
     from PIL import Image
@@ -268,18 +266,27 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         torch.testing.assert_close(out_embeds, out_ids)
 
-    def test_unpad_image(self):
-        original_size = (400, 400)
-
-        # Test case width is padded
-        pixel_values = floats_tensor([3, 400, 601])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
-        self.assertEqual(unpadded_tensor.shape[1:], original_size)
-
-        # Test case height is padded
-        pixel_values = floats_tensor([3, 503, 400])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
-        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+    def test_odd_sized_image(self):
+        # prepare model configuration
+        config = self.model_tester.get_config()
+
+        # prepare input
+        num_image_tokens = 10
+        pixel_values = floats_tensor([1, 2, 3, config.vision_config.image_size, config.vision_config.image_size])
+        input_ids = ids_tensor([1, 64], config.text_config.vocab_size - 2) + 2
+        input_ids[:, :num_image_tokens] = config.image_token_index
+        attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
+        inputs_dict = {
+            "pixel_values": pixel_values,
+            "image_sizes": torch.tensor([[13, 16]]),  # odd-sized image
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+
+        # forward with odd-sized image input
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model(**inputs_dict)
 
     @parameterized.expand(
         [
@@ -11,11 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import json
 import shutil
 import tempfile
 import unittest
 
+import torch
+
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
@@ -90,3 +93,33 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         # so we check if the same template is loaded
         processor_dict = self.prepare_processor_dict()
         self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
+
+    def test_image_token_filling(self):
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
+        processor.patch_size = 14
+        processor.vision_feature_select_strategy = "default"
+        processor.image_processor.crop_size = {"height": 336, "width": 336}
+        processor.image_processor.size = {"shortest_edge": 336}
+        processor.image_processor.image_grid_pinpoints = [[672, 336]]
+        processor.num_image_tokens = (processor.image_processor.size["shortest_edge"] // processor.patch_size) ** 2
+        # Important to check with non square image
+        image = torch.randint(0, 2, (3, 503, 316))
+        expected_image_tokens = 1525
+        image_token_index = processor.image_token_id
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": "What is shown in this image?"},
+                ],
+            },
+        ]
+        inputs = processor(
+            text=[processor.apply_chat_template(messages)],
+            images=[image],
+            return_tensors="pt",
+        )
+        image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
+        self.assertEqual(expected_image_tokens, image_tokens)