Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 12:50:06 +06:00
Add SynthID (watermarking by Google DeepMind) (#34350)
* Add SynthIDTextWatermarkLogitsProcessor * Resolving comments. * Resolving comments. * Resolving commits. * Improving SynthIDWatermark tests. * switch to PT version * detector as pretrained model + style * update training + style * rebase * Update logits_process.py * Improving SynthIDWatermark tests. * Shift detector training to wikitext negatives and stabilize with lower learning rate. * Clean up. * in for 7B * cleanup * Support python 3.8. * README and final cleanup. * HF Hub upload and initialize. * Update requirements for synthid_text. * Adding SynthIDTextWatermarkDetector. * Detector testing. * Documentation changes. * Copyrights fix. * Fix detector api. * ironing out errors * ironing out errors * training checks * make fixup and make fix-copies * docstrings and add to docs * copyright * BC * test docstrings * move import * protect type hints * top level imports * watermarking example * direct imports * tpr fpr meaning * process_kwargs * SynthIDTextWatermarkingConfig docstring * assert -> exception * example updates * no immutable dict (can't be serialized) * pack fn * einsum equivalent * import order * fix test on gpu * add detector example --------- Co-authored-by: Sumedh Ghaisas <sumedhg@google.com> Co-authored-by: Marc Sun <marc@huggingface.co> Co-authored-by: sumedhghaisas2 <138781311+sumedhghaisas2@users.noreply.github.com> Co-authored-by: raushan <raushan@huggingface.co>
This commit is contained in:
parent e50bf61dec
commit b0f0c61899
@ -185,6 +185,9 @@ generation.
[[autodoc]] SuppressTokensLogitsProcessor
- __call__

[[autodoc]] SynthIDTextWatermarkLogitsProcessor
- __call__

[[autodoc]] TemperatureLogitsWarper
- __call__

@ -418,5 +421,20 @@ A [`Constraint`] can be used to force the generation to include specific tokens

## Watermark Utils

[[autodoc]] WatermarkingConfig
- __call__

[[autodoc]] WatermarkDetector
- __call__

[[autodoc]] BayesianDetectorConfig
- __call__

[[autodoc]] BayesianDetectorModel
- __call__

[[autodoc]] SynthIDTextWatermarkingConfig
- __call__

[[autodoc]] SynthIDTextWatermarkDetector
- __call__

@ -41,8 +41,6 @@ like token streaming.
- validate
- get_generation_mode

[[autodoc]] generation.WatermarkingConfig

## GenerationMixin

[[autodoc]] GenerationMixin

34
examples/research_projects/synthid_text/README.md
Normal file
@ -0,0 +1,34 @@
# SynthID Text

This project showcases the use of SynthID Text for watermarking LLM outputs. The code in this repo also
demonstrates training a detector to recognize such watermarked text. The trained detector can be uploaded to
a private HF Hub repo (private for security reasons) and initialized again through pretrained model loading, as also shown in this script.

See our blog post: https://huggingface.co/blog/synthid-text


## Python version

You need Python 3.9 to run this example.

## Installation and running

Once you have installed transformers, install the requirements for this project from the requirements.txt provided in this folder.

```
pip install -r requirements.txt
```

## To run the detector training

```
python detector_training.py --model_name=google/gemma-7b-it
```

Check the script for more parameters that are tunable, and see the paper at
https://www.nature.com/articles/s41586-024-08025-4 for more information on these parameters.
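
The script also exposes sampling and data-size flags (for example `--temperature`, `--top_k`, `--num_negatives`, `--generation_length`, and Hub upload options). A sketch of a fuller invocation is shown below; the values and the Hub repo name are only illustrative, not recommendations:

```
python detector_training.py \
    --model_name=google/gemma-7b-it \
    --temperature=0.7 \
    --top_k=40 \
    --num_negatives=10000 \
    --generation_length=512 \
    --save_model_to_hf_hub \
    --hf_hub_model_name=<your_hf_username>/synthid-text-detector
```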

## Caveat

Make sure to run the training of the detector and the detection itself on the same hardware
(CPU, GPU, or TPU) to get consistent results (we use deterministic randomness, which is hardware dependent).
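
## Loading a trained detector

The detector can be loaded back from the (private) HF Hub repo it was uploaded to and used for detection. The snippet below is a minimal sketch of that flow, mirroring `detector_training.py`; the repo id is a placeholder you should replace with your own:

```python
import torch
from transformers import (
    AutoTokenizer,
    BayesianDetectorModel,
    SynthIDTextWatermarkDetector,
    SynthIDTextWatermarkLogitsProcessor,
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# "<your_hf_username>/synthid-text-detector" is a placeholder for the Hub repo the trained detector was pushed to.
detector_model = BayesianDetectorModel.from_pretrained("<your_hf_username>/synthid-text-detector").to(device)

# The training script stores the base model name and watermarking config on the detector config
# via `set_detector_information`, so they can be recovered here.
base_model_name = detector_model.config.model_name
watermarking_config = detector_model.config.watermarking_config

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
tokenizer.pad_token = tokenizer.eos_token
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config, device=device)
detector = SynthIDTextWatermarkDetector(detector_model, logits_processor, tokenizer)

# `generated_ids` should be watermarked completions with the prompt tokens stripped,
# exactly as produced in detector_training.py; the detector returns the probability of a watermark.
# score = detector(generated_ids)
```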
|
502
examples/research_projects/synthid_text/detector_training.py
Normal file
@ -0,0 +1,502 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google DeepMind.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BayesianDetectorConfig,
|
||||
BayesianDetectorModel,
|
||||
SynthIDTextWatermarkDetector,
|
||||
SynthIDTextWatermarkingConfig,
|
||||
SynthIDTextWatermarkLogitsProcessor,
|
||||
)
|
||||
from utils import (
|
||||
get_tokenized_uwm_outputs,
|
||||
get_tokenized_wm_outputs,
|
||||
process_raw_model_outputs,
|
||||
update_fn_if_fpr_tpr,
|
||||
upload_model_to_hf,
|
||||
)
|
||||
|
||||
|
||||
@enum.unique
|
||||
class ValidationMetric(enum.Enum):
|
||||
"""Direction along the z-axis."""
|
||||
|
||||
TPR_AT_FPR = "tpr_at_fpr"
|
||||
CROSS_ENTROPY = "cross_entropy"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
"""Training arguments pertaining to the training loop itself."""
|
||||
|
||||
eval_metric: Optional[ValidationMetric] = dataclasses.field(
|
||||
default=ValidationMetric.TPR_AT_FPR, metadata={"help": "The evaluation metric used."}
|
||||
)
|
||||
|
||||
|
||||
def train_detector(
|
||||
detector: torch.nn.Module,
|
||||
g_values: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
watermarked: torch.Tensor,
|
||||
epochs: int = 250,
|
||||
learning_rate: float = 1e-3,
|
||||
minibatch_size: int = 64,
|
||||
seed: int = 0,
|
||||
l2_weight: float = 0.0,
|
||||
shuffle: bool = True,
|
||||
g_values_val: Optional[torch.Tensor] = None,
|
||||
mask_val: Optional[torch.Tensor] = None,
|
||||
watermarked_val: Optional[torch.Tensor] = None,
|
||||
verbose: bool = False,
|
||||
validation_metric: ValidationMetric = ValidationMetric.TPR_AT_FPR,
|
||||
) -> Tuple[Dict[str, Any], float]:
|
||||
"""Trains a Bayesian detector model.
|
||||
|
||||
Args:
|
||||
g_values: g-values of shape [num_train, seq_len, watermarking_depth].
|
||||
mask: A binary array shape [num_train, seq_len] indicating which g-values
|
||||
should be used. g-values with mask value 0 are discarded.
|
||||
watermarked: A binary array of shape [num_train] indicating whether the
|
||||
example is watermarked (0: unwatermarked, 1: watermarked).
|
||||
epochs: Number of epochs to train for.
|
||||
learning_rate: Learning rate for optimizer.
|
||||
minibatch_size: Minibatch size for training. Note that a minibatch
|
||||
requires ~ 32 * minibatch_size * seq_len * watermarked_depth *
|
||||
watermarked_depth bits of memory.
|
||||
seed: Seed for parameter initialization.
|
||||
l2_weight: Weight to apply to L2 regularization for delta parameters.
|
||||
shuffle: Whether to shuffle before training.
|
||||
g_values_val: Validation g-values of shape [num_val, seq_len,
|
||||
watermarking_depth].
|
||||
mask_val: Validation mask of shape [num_val, seq_len].
|
||||
watermarked_val: Validation watermark labels of shape [num_val].
|
||||
verbose: Boolean indicating verbosity of training. If true, the loss will
|
||||
be printed. Defaulted to False.
|
||||
use_tpr_fpr_for_val: Whether to use TPR@FPR=1% as metric for validation.
|
||||
If false, use cross entropy loss.
|
||||
|
||||
Returns:
|
||||
Tuple of
|
||||
training_history: Training history keyed by epoch number where the
|
||||
values are
|
||||
dictionaries containing the loss, validation loss, and model
|
||||
parameters,
|
||||
keyed by
|
||||
'loss', 'val_loss', and 'params', respectively.
|
||||
min_val_loss: Minimum validation loss achieved during training.
|
||||
"""
|
||||
|
||||
# Set the random seed for reproducibility
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# Shuffle the data if required
|
||||
if shuffle:
|
||||
indices = torch.randperm(len(g_values))
|
||||
g_values = g_values[indices]
|
||||
mask = mask[indices]
|
||||
watermarked = watermarked[indices]
|
||||
|
||||
# Initialize optimizer
|
||||
optimizer = torch.optim.Adam(detector.parameters(), lr=learning_rate)
|
||||
history = {}
|
||||
min_val_loss = float("inf")
|
||||
|
||||
for epoch in range(epochs):
|
||||
losses = []
|
||||
detector.train()
|
||||
num_batches = len(g_values) // minibatch_size
|
||||
for i in range(0, len(g_values), minibatch_size):
|
||||
end = i + minibatch_size
|
||||
if end > len(g_values):
|
||||
break
|
||||
loss_batch_weight = l2_weight / num_batches
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss = detector(
|
||||
g_values=g_values[i:end],
|
||||
mask=mask[i:end],
|
||||
labels=watermarked[i:end],
|
||||
loss_batch_weight=loss_batch_weight,
|
||||
)[1]
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
losses.append(loss.item())
|
||||
train_loss = sum(losses) / len(losses)
|
||||
|
||||
val_losses = []
val_loss = None  # stays None when no validation data is provided
|
||||
if g_values_val is not None:
|
||||
detector.eval()
|
||||
if validation_metric == ValidationMetric.TPR_AT_FPR:
|
||||
val_loss = update_fn_if_fpr_tpr(
|
||||
detector,
|
||||
g_values_val,
|
||||
mask_val,
|
||||
watermarked_val,
|
||||
minibatch_size=minibatch_size,
|
||||
)
|
||||
else:
|
||||
for i in range(0, len(g_values_val), minibatch_size):
|
||||
end = i + minibatch_size
|
||||
if end > len(g_values_val):
|
||||
break
|
||||
with torch.no_grad():
|
||||
v_loss = detector(
|
||||
g_values=g_values_val[i:end],
|
||||
mask=mask_val[i:end],
|
||||
labels=watermarked_val[i:end],
|
||||
loss_batch_weight=0,
|
||||
)[1]
|
||||
val_losses.append(v_loss.item())
|
||||
val_loss = sum(val_losses) / len(val_losses)
|
||||
|
||||
# Store training history
|
||||
history[epoch + 1] = {"loss": train_loss, "val_loss": val_loss}
|
||||
if verbose:
|
||||
if val_loss is not None:
|
||||
print(f"Epoch {epoch}: loss {loss} (train), {val_loss} (val)")
|
||||
else:
|
||||
print(f"Epoch {epoch}: loss {loss} (train)")
|
||||
|
||||
if val_loss is not None and val_loss < min_val_loss:
|
||||
min_val_loss = val_loss
|
||||
best_val_epoch = epoch
|
||||
|
||||
if verbose:
|
||||
print(f"Best val Epoch: {best_val_epoch}, min_val_loss: {min_val_loss}")
|
||||
|
||||
return history, min_val_loss
|
||||
|
||||
|
||||
def train_best_detector(
|
||||
tokenized_wm_outputs: Union[List[np.ndarray], np.ndarray],
|
||||
tokenized_uwm_outputs: Union[List[np.ndarray], np.ndarray],
|
||||
logits_processor: SynthIDTextWatermarkLogitsProcessor,
|
||||
tokenizer: Any,
|
||||
torch_device: torch.device,
|
||||
test_size: float = 0.3,
|
||||
pos_truncation_length: Optional[int] = 200,
|
||||
neg_truncation_length: Optional[int] = 100,
|
||||
max_padded_length: int = 2300,
|
||||
n_epochs: int = 50,
|
||||
learning_rate: float = 2.1e-2,
|
||||
l2_weights: np.ndarray = np.logspace(-3, -2, num=4),
|
||||
verbose: bool = False,
|
||||
validation_metric: ValidationMetric = ValidationMetric.TPR_AT_FPR,
|
||||
):
|
||||
"""Train and return the best detector given range of hyperparameters.
|
||||
|
||||
In practice, we have found that tuning pos_truncation_length,
|
||||
neg_truncation_length, n_epochs, learning_rate and l2_weights can help
|
||||
improve the performance of the detector. We reccommend tuning these
|
||||
parameters for your data.
|
||||
"""
|
||||
l2_weights = list(l2_weights)
|
||||
|
||||
(
|
||||
train_g_values,
|
||||
train_masks,
|
||||
train_labels,
|
||||
cv_g_values,
|
||||
cv_masks,
|
||||
cv_labels,
|
||||
) = process_raw_model_outputs(
|
||||
logits_processor,
|
||||
tokenizer,
|
||||
pos_truncation_length,
|
||||
neg_truncation_length,
|
||||
max_padded_length,
|
||||
tokenized_wm_outputs,
|
||||
test_size,
|
||||
tokenized_uwm_outputs,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
best_detector = None
|
||||
lowest_loss = float("inf")
|
||||
val_losses = []
|
||||
for l2_weight in l2_weights:
|
||||
config = BayesianDetectorConfig(watermarking_depth=len(logits_processor.keys))
|
||||
detector = BayesianDetectorModel(config).to(torch_device)
|
||||
_, min_val_loss = train_detector(
|
||||
detector=detector,
|
||||
g_values=train_g_values,
|
||||
mask=train_masks,
|
||||
watermarked=train_labels,
|
||||
g_values_val=cv_g_values,
|
||||
mask_val=cv_masks,
|
||||
watermarked_val=cv_labels,
|
||||
learning_rate=learning_rate,
|
||||
l2_weight=l2_weight,
|
||||
epochs=n_epochs,
|
||||
verbose=verbose,
|
||||
validation_metric=validation_metric,
|
||||
)
|
||||
val_losses.append(min_val_loss)
|
||||
if min_val_loss < lowest_loss:
|
||||
lowest_loss = min_val_loss
|
||||
best_detector = detector
|
||||
return best_detector, lowest_loss
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="google/gemma-2b-it",
|
||||
help=("LM model to train the detector for."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help=("Temperature to sample from the model."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top_k",
|
||||
type=int,
|
||||
default=40,
|
||||
help=("Top K for sampling."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top_p",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help=("Top P for sampling."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_negatives",
|
||||
type=int,
|
||||
default=10000,
|
||||
help=("Number of negatives for detector training."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pos_batch_size",
|
||||
type=int,
|
||||
default=32,
|
||||
help=("Batch size of watermarked positives while sampling."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_pos_batch",
|
||||
type=int,
|
||||
default=313,
|
||||
help=("Number of positive batches for training."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--generation_length",
|
||||
type=int,
|
||||
default=512,
|
||||
help=("Generation length for sampling."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_model_to_hf_hub",
|
||||
action="store_true",
|
||||
help=("Whether to save the trained model HF hub. By default it will be a private repo."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_from_hf_hub",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether to load trained detector model from HF Hub, make sure its the model trained on the same model "
|
||||
"we are loading in the script."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_hub_model_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=("HF hub model name for loading of saving the model."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_detector_on_prompts",
|
||||
action="store_true",
|
||||
help=("Evaluate detector on a prompt and print probability of watermark."),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
model_name = args.model_name
|
||||
temperature = args.temperature
|
||||
top_k = args.top_k
|
||||
top_p = args.top_p
|
||||
num_negatives = args.num_negatives
|
||||
pos_batch_size = args.pos_batch_size
|
||||
num_pos_batch = args.num_pos_batch
|
||||
if num_pos_batch < 10:
|
||||
raise ValueError("--num_pos_batch should be greater than 10.")
|
||||
generation_length = args.generation_length
|
||||
save_model_to_hf_hub = args.save_model_to_hf_hub
|
||||
load_from_hf_hub = args.load_from_hf_hub
|
||||
repo_name = args.hf_hub_model_name
|
||||
eval_detector_on_prompts = args.eval_detector_on_prompts
|
||||
|
||||
NEG_BATCH_SIZE = 32
|
||||
|
||||
# Truncate outputs to this length for training.
|
||||
POS_TRUNCATION_LENGTH = 200
|
||||
NEG_TRUNCATION_LENGTH = 100
|
||||
# Pad truncated outputs to this length for equal shape across all batches.
|
||||
MAX_PADDED_LENGTH = 1000
|
||||
|
||||
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||
if DEVICE.type not in ("cuda", "tpu"):
|
||||
raise ValueError("We have found the training stable on GPU and TPU, we are working on" " a fix for CPUs")
|
||||
|
||||
model = None
|
||||
if not load_from_hf_hub:
|
||||
# Change this to make your watermark unique. Check documentation in the paper to understand the
|
||||
# impact of these parameters.
|
||||
DEFAULT_WATERMARKING_CONFIG = {
|
||||
"ngram_len": 5, # This corresponds to H=4 context window size in the paper.
|
||||
"keys": [
|
||||
654,
|
||||
400,
|
||||
836,
|
||||
123,
|
||||
340,
|
||||
443,
|
||||
597,
|
||||
160,
|
||||
57,
|
||||
29,
|
||||
590,
|
||||
639,
|
||||
13,
|
||||
715,
|
||||
468,
|
||||
990,
|
||||
966,
|
||||
226,
|
||||
324,
|
||||
585,
|
||||
118,
|
||||
504,
|
||||
421,
|
||||
521,
|
||||
129,
|
||||
669,
|
||||
732,
|
||||
225,
|
||||
90,
|
||||
960,
|
||||
],
|
||||
"sampling_table_size": 2**16,
|
||||
"sampling_table_seed": 0,
|
||||
"context_history_size": 1024,
|
||||
}
|
||||
watermark_config = SynthIDTextWatermarkingConfig(**DEFAULT_WATERMARKING_CONFIG)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
logits_processor = SynthIDTextWatermarkLogitsProcessor(**DEFAULT_WATERMARKING_CONFIG, device=DEVICE)
|
||||
tokenized_wm_outputs = get_tokenized_wm_outputs(
|
||||
model,
|
||||
tokenizer,
|
||||
watermark_config,
|
||||
num_pos_batch,
|
||||
pos_batch_size,
|
||||
temperature,
|
||||
generation_length,
|
||||
top_k,
|
||||
top_p,
|
||||
DEVICE,
|
||||
)
|
||||
tokenized_uwm_outputs = get_tokenized_uwm_outputs(num_negatives, NEG_BATCH_SIZE, tokenizer, DEVICE)
|
||||
|
||||
best_detector, lowest_loss = train_best_detector(
|
||||
tokenized_wm_outputs=tokenized_wm_outputs,
|
||||
tokenized_uwm_outputs=tokenized_uwm_outputs,
|
||||
logits_processor=logits_processor,
|
||||
tokenizer=tokenizer,
|
||||
torch_device=DEVICE,
|
||||
test_size=0.3,
|
||||
pos_truncation_length=POS_TRUNCATION_LENGTH,
|
||||
neg_truncation_length=NEG_TRUNCATION_LENGTH,
|
||||
max_padded_length=MAX_PADDED_LENGTH,
|
||||
n_epochs=100,
|
||||
learning_rate=3e-3,
|
||||
l2_weights=[
|
||||
0,
|
||||
],
|
||||
verbose=True,
|
||||
validation_metric=ValidationMetric.TPR_AT_FPR,
|
||||
)
|
||||
else:
|
||||
if repo_name is None:
|
||||
raise ValueError("When loading from pretrained detector model name cannot be None.")
|
||||
best_detector = BayesianDetectorModel.from_pretrained(repo_name).to(DEVICE)
|
||||
|
||||
best_detector.config.set_detector_information(
|
||||
model_name=model_name, watermarking_config=DEFAULT_WATERMARKING_CONFIG
|
||||
)
|
||||
if save_model_to_hf_hub:
|
||||
upload_model_to_hf(best_detector, repo_name)
|
||||
|
||||
# Evaluate model response with the detector
|
||||
if eval_detector_on_prompts:
|
||||
model_name = best_detector.config.model_name
|
||||
watermark_config_dict = best_detector.config.watermarking_config
|
||||
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermark_config_dict, device=DEVICE)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
synthid_text_detector = SynthIDTextWatermarkDetector(best_detector, logits_processor, tokenizer)
|
||||
|
||||
if model is None:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
|
||||
watermarking_config = SynthIDTextWatermarkingConfig(**watermark_config_dict)
|
||||
|
||||
prompts = ["Write a essay on cats."]
|
||||
inputs = tokenizer(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(DEVICE)
|
||||
|
||||
_, inputs_len = inputs["input_ids"].shape
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
watermarking_config=watermarking_config,
|
||||
do_sample=True,
|
||||
max_length=inputs_len + generation_length,
|
||||
temperature=temperature,
|
||||
top_k=40,
|
||||
top_p=1.0,
|
||||
)
|
||||
outputs = outputs[:, inputs_len:]
|
||||
result = synthid_text_detector(outputs)
|
||||
|
||||
# You should set this based on expected fpr (false positive rate) and tpr (true positive rate).
|
||||
# Check our demo at HF Spaces for more info.
|
||||
upper_threshold = 0.95
|
||||
lower_threshold = 0.12
|
||||
if result[0][0] > upper_threshold:
|
||||
print("The text is watermarked.")
|
||||
elif lower_threshold < result[0][0] < upper_threshold:
|
||||
print("It is hard to determine if the text is watermarked or not.")
|
||||
else:
|
||||
print("The text is not watermarked.")
|
5
examples/research_projects/synthid_text/requirements.txt
Normal file
@ -0,0 +1,5 @@
tensorflow-datasets>=4.9.3
torch >= 1.3
datasets
scikit-learn
tensorflow
408
examples/research_projects/synthid_text/utils.py
Normal file
@ -0,0 +1,408 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google DeepMind.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
import torch
|
||||
import tqdm
|
||||
from huggingface_hub import HfApi, create_repo
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
from sklearn import model_selection
|
||||
|
||||
import transformers
|
||||
|
||||
|
||||
def pad_to_len(
|
||||
arr: torch.Tensor,
|
||||
target_len: int,
|
||||
left_pad: bool,
|
||||
eos_token: int,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""Pad or truncate array to given length."""
|
||||
if arr.shape[1] < target_len:
|
||||
shape_for_ones = list(arr.shape)
|
||||
shape_for_ones[1] = target_len - shape_for_ones[1]
|
||||
padded = (
|
||||
torch.ones(
|
||||
shape_for_ones,
|
||||
device=device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
* eos_token
|
||||
)
|
||||
if not left_pad:
|
||||
arr = torch.concatenate((arr, padded), dim=1)
|
||||
else:
|
||||
arr = torch.concatenate((padded, arr), dim=1)
|
||||
else:
|
||||
arr = arr[:, :target_len]
|
||||
return arr
|
||||
|
||||
|
||||
def filter_and_truncate(
|
||||
outputs: torch.Tensor,
|
||||
truncation_length: Optional[int],
|
||||
eos_token_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Filter and truncate outputs to given length.
|
||||
|
||||
Args:
|
||||
outputs: output tensor of shape [batch_size, output_len]
|
||||
truncation_length: Length to truncate the final output.
|
||||
eos_token_mask: EOS token mask of shape [batch_size, output_len]
|
||||
|
||||
Returns:
|
||||
output tensor of shape [batch_size, truncation_length].
|
||||
"""
|
||||
if truncation_length:
|
||||
outputs = outputs[:, :truncation_length]
|
||||
truncation_mask = torch.sum(eos_token_mask, dim=1) >= truncation_length
|
||||
return outputs[truncation_mask, :]
|
||||
return outputs
|
||||
|
||||
|
||||
def process_outputs_for_training(
|
||||
all_outputs: List[torch.Tensor],
|
||||
logits_processor: transformers.generation.SynthIDTextWatermarkLogitsProcessor,
|
||||
tokenizer: Any,
|
||||
pos_truncation_length: Optional[int],
|
||||
neg_truncation_length: Optional[int],
|
||||
max_length: int,
|
||||
is_cv: bool,
|
||||
is_pos: bool,
|
||||
torch_device: torch.device,
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
"""Process raw model outputs into format understandable by the detector.
|
||||
|
||||
Args:
|
||||
all_outputs: sequence of outputs of shape [batch_size, output_len].
|
||||
logits_processor: logits processor used for watermarking.
|
||||
tokenizer: tokenizer used for the model.
|
||||
pos_truncation_length: Length to truncate wm outputs.
|
||||
neg_truncation_length: Length to truncate uwm outputs.
|
||||
max_length: Length to pad truncated outputs so that all processed entries
have the same shape.
|
||||
is_cv: Process given outputs for cross validation.
|
||||
is_pos: Process given outputs for positives.
|
||||
torch_device: torch device to use.
|
||||
|
||||
Returns:
|
||||
Tuple of
|
||||
all_masks: list of masks of shape [batch_size, max_length].
|
||||
all_g_values: list of g_values of shape [batch_size, max_length, depth].
|
||||
"""
|
||||
all_masks = []
|
||||
all_g_values = []
|
||||
for outputs in tqdm.tqdm(all_outputs):
|
||||
# outputs is of shape [batch_size, output_len].
|
||||
# output_len can differ from batch to batch.
|
||||
eos_token_mask = logits_processor.compute_eos_token_mask(
|
||||
input_ids=outputs,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
if is_pos or is_cv:
|
||||
# filter with length for positives for both train and CV.
|
||||
# We also filter for length when CV negatives are processed.
|
||||
outputs = filter_and_truncate(outputs, pos_truncation_length, eos_token_mask)
|
||||
elif not is_pos and not is_cv:
|
||||
outputs = filter_and_truncate(outputs, neg_truncation_length, eos_token_mask)
|
||||
|
||||
# If no filtered outputs skip this batch.
|
||||
if outputs.shape[0] == 0:
|
||||
continue
|
||||
|
||||
# All outputs are padded to max-length with eos-tokens.
|
||||
outputs = pad_to_len(outputs, max_length, False, tokenizer.eos_token_id, torch_device)
|
||||
# outputs shape [num_filtered_entries, max_length]
|
||||
|
||||
eos_token_mask = logits_processor.compute_eos_token_mask(
|
||||
input_ids=outputs,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
context_repetition_mask = logits_processor.compute_context_repetition_mask(
|
||||
input_ids=outputs,
|
||||
)
|
||||
|
||||
# context_repetition_mask of shape [num_filtered_entries, max_length -
|
||||
# (ngram_len - 1)].
|
||||
context_repetition_mask = pad_to_len(context_repetition_mask, max_length, True, 0, torch_device)
|
||||
# We pad on left to get same max_length shape.
|
||||
# context_repetition_mask of shape [num_filtered_entries, max_length].
|
||||
combined_mask = context_repetition_mask * eos_token_mask
|
||||
|
||||
g_values = logits_processor.compute_g_values(
|
||||
input_ids=outputs,
|
||||
)
|
||||
|
||||
# g_values of shape [num_filtered_entries, max_length - (ngram_len - 1),
|
||||
# depth].
|
||||
g_values = pad_to_len(g_values, max_length, True, 0, torch_device)
|
||||
|
||||
# We pad on left to get same max_length shape.
|
||||
# g_values of shape [num_filtered_entries, max_length, depth].
|
||||
all_masks.append(combined_mask)
|
||||
all_g_values.append(g_values)
|
||||
return all_masks, all_g_values
|
||||
|
||||
|
||||
def tpr_at_fpr(detector, detector_inputs, w_true, minibatch_size, target_fpr=0.01) -> float:
|
||||
"""Calculates true positive rate (TPR) at false positive rate (FPR)=target_fpr."""
|
||||
positive_idxs = w_true == 1
|
||||
negative_idxs = w_true == 0
|
||||
num_samples = detector_inputs[0].size(0)
|
||||
|
||||
w_preds = []
|
||||
for start in range(0, num_samples, minibatch_size):
|
||||
end = start + minibatch_size
|
||||
detector_inputs_ = (
|
||||
detector_inputs[0][start:end],
|
||||
detector_inputs[1][start:end],
|
||||
)
|
||||
with torch.no_grad():
|
||||
w_pred = detector(*detector_inputs_)[0]
|
||||
w_preds.append(w_pred)
|
||||
|
||||
w_pred = torch.cat(w_preds, dim=0) # Concatenate predictions
|
||||
positive_scores = w_pred[positive_idxs]
|
||||
negative_scores = w_pred[negative_idxs]
|
||||
|
||||
# Calculate the FPR threshold
|
||||
# Note: percentile -> quantile
|
||||
fpr_threshold = torch.quantile(negative_scores, 1 - target_fpr)
|
||||
# Note: need to switch to FP32 since torch.mean doesn't work with torch.bool
|
||||
return torch.mean((positive_scores >= fpr_threshold).to(dtype=torch.float32)).item() # TPR
|
||||
|
||||
|
||||
def update_fn_if_fpr_tpr(detector, g_values_val, mask_val, watermarked_val, minibatch_size):
|
||||
"""Loss function for negative TPR@FPR=1% as the validation loss."""
|
||||
tpr_ = tpr_at_fpr(
|
||||
detector=detector,
|
||||
detector_inputs=(g_values_val, mask_val),
|
||||
w_true=watermarked_val,
|
||||
minibatch_size=minibatch_size,
|
||||
)
|
||||
return -tpr_
|
||||
|
||||
|
||||
def process_raw_model_outputs(
|
||||
logits_processor,
|
||||
tokenizer,
|
||||
pos_truncation_length,
|
||||
neg_truncation_length,
|
||||
max_padded_length,
|
||||
tokenized_wm_outputs,
|
||||
test_size,
|
||||
tokenized_uwm_outputs,
|
||||
torch_device,
|
||||
):
|
||||
# Split data into train and CV
|
||||
train_wm_outputs, cv_wm_outputs = model_selection.train_test_split(tokenized_wm_outputs, test_size=test_size)
|
||||
|
||||
train_uwm_outputs, cv_uwm_outputs = model_selection.train_test_split(tokenized_uwm_outputs, test_size=test_size)
|
||||
|
||||
process_kwargs = {
|
||||
"logits_processor": logits_processor,
|
||||
"tokenizer": tokenizer,
|
||||
"pos_truncation_length": pos_truncation_length,
|
||||
"neg_truncation_length": neg_truncation_length,
|
||||
"max_length": max_padded_length,
|
||||
"torch_device": torch_device,
|
||||
}
|
||||
|
||||
# Process both train and CV data for training
|
||||
wm_masks_train, wm_g_values_train = process_outputs_for_training(
|
||||
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in train_wm_outputs],
|
||||
is_pos=True,
|
||||
is_cv=False,
|
||||
**process_kwargs,
|
||||
)
|
||||
wm_masks_cv, wm_g_values_cv = process_outputs_for_training(
|
||||
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in cv_wm_outputs],
|
||||
is_pos=True,
|
||||
is_cv=True,
|
||||
**process_kwargs,
|
||||
)
|
||||
uwm_masks_train, uwm_g_values_train = process_outputs_for_training(
|
||||
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in train_uwm_outputs],
|
||||
is_pos=False,
|
||||
is_cv=False,
|
||||
**process_kwargs,
|
||||
)
|
||||
uwm_masks_cv, uwm_g_values_cv = process_outputs_for_training(
|
||||
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in cv_uwm_outputs],
|
||||
is_pos=False,
|
||||
is_cv=True,
|
||||
**process_kwargs,
|
||||
)
|
||||
|
||||
# We get list of data; here we concat all together to be passed to the detector.
|
||||
def pack(mask, g_values):
|
||||
mask = torch.cat(mask, dim=0)
|
||||
g = torch.cat(g_values, dim=0)
|
||||
return mask, g
|
||||
|
||||
wm_masks_train, wm_g_values_train = pack(wm_masks_train, wm_g_values_train)
|
||||
# Note: Use float instead of bool. Otherwise, the entropy calculation doesn't work
|
||||
wm_labels_train = torch.ones((wm_masks_train.shape[0],), dtype=torch.float, device=torch_device)
|
||||
|
||||
wm_masks_cv, wm_g_values_cv = pack(wm_masks_cv, wm_g_values_cv)
|
||||
wm_labels_cv = torch.ones((wm_masks_cv.shape[0],), dtype=torch.float, device=torch_device)
|
||||
|
||||
uwm_masks_train, uwm_g_values_train = pack(uwm_masks_train, uwm_g_values_train)
|
||||
uwm_labels_train = torch.zeros((uwm_masks_train.shape[0],), dtype=torch.float, device=torch_device)
|
||||
|
||||
uwm_masks_cv, uwm_g_values_cv = pack(uwm_masks_cv, uwm_g_values_cv)
|
||||
uwm_labels_cv = torch.zeros((uwm_masks_cv.shape[0],), dtype=torch.float, device=torch_device)
|
||||
|
||||
# Concat pos and negatives data together.
|
||||
train_g_values = torch.cat((wm_g_values_train, uwm_g_values_train), dim=0).squeeze()
|
||||
train_labels = torch.cat((wm_labels_train, uwm_labels_train), axis=0).squeeze()
|
||||
train_masks = torch.cat((wm_masks_train, uwm_masks_train), axis=0).squeeze()
|
||||
|
||||
cv_g_values = torch.cat((wm_g_values_cv, uwm_g_values_cv), axis=0).squeeze()
|
||||
cv_labels = torch.cat((wm_labels_cv, uwm_labels_cv), axis=0).squeeze()
|
||||
cv_masks = torch.cat((wm_masks_cv, uwm_masks_cv), axis=0).squeeze()
|
||||
|
||||
# Shuffle data.
|
||||
shuffled_idx = torch.randperm(train_g_values.shape[0]) # Use torch for GPU compatibility
|
||||
|
||||
train_g_values = train_g_values[shuffled_idx]
|
||||
train_labels = train_labels[shuffled_idx]
|
||||
train_masks = train_masks[shuffled_idx]
|
||||
|
||||
# Shuffle the cross-validation data
|
||||
shuffled_idx_cv = torch.randperm(cv_g_values.shape[0]) # Use torch for GPU compatibility
|
||||
cv_g_values = cv_g_values[shuffled_idx_cv]
|
||||
cv_labels = cv_labels[shuffled_idx_cv]
|
||||
cv_masks = cv_masks[shuffled_idx_cv]
|
||||
|
||||
# Del some variables so we free up GPU memory.
|
||||
del (
|
||||
wm_g_values_train,
|
||||
wm_labels_train,
|
||||
wm_masks_train,
|
||||
wm_g_values_cv,
|
||||
wm_labels_cv,
|
||||
wm_masks_cv,
|
||||
)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return train_g_values, train_masks, train_labels, cv_g_values, cv_masks, cv_labels
|
||||
|
||||
|
||||
def get_tokenized_uwm_outputs(num_negatives, neg_batch_size, tokenizer, device):
|
||||
dataset, info = tfds.load("wikipedia/20230601.en", split="train", with_info=True)
|
||||
dataset = dataset.take(num_negatives)
|
||||
|
||||
# Convert the dataset to a DataFrame
|
||||
df = tfds.as_dataframe(dataset, info)
|
||||
ds = tf.data.Dataset.from_tensor_slices(dict(df))
|
||||
tf.random.set_seed(0)
|
||||
ds = ds.shuffle(buffer_size=10_000)
|
||||
ds = ds.batch(batch_size=neg_batch_size)
|
||||
|
||||
tokenized_uwm_outputs = []
|
||||
# Pad to this length (on the right) for batching.
|
||||
padded_length = 1000
|
||||
for i, batch in tqdm.tqdm(enumerate(ds)):
|
||||
responses = [val.decode() for val in batch["text"].numpy()]
|
||||
inputs = tokenizer(
|
||||
responses,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(device)
|
||||
inputs = inputs["input_ids"].cpu().numpy()
|
||||
if inputs.shape[1] >= padded_length:
|
||||
inputs = inputs[:, :padded_length]
|
||||
else:
|
||||
inputs = np.concatenate(
|
||||
[inputs, np.ones((inputs.shape[0], padded_length - inputs.shape[1])) * tokenizer.eos_token_id], axis=1
|
||||
)
|
||||
tokenized_uwm_outputs.append(inputs)
|
||||
if len(tokenized_uwm_outputs) * neg_batch_size > num_negatives:
|
||||
break
|
||||
return tokenized_uwm_outputs
|
||||
|
||||
|
||||
def get_tokenized_wm_outputs(
|
||||
model,
|
||||
tokenizer,
|
||||
watermark_config,
|
||||
num_pos_batches,
|
||||
pos_batch_size,
|
||||
temperature,
|
||||
max_output_len,
|
||||
top_k,
|
||||
top_p,
|
||||
device,
|
||||
):
|
||||
eli5_prompts = datasets.load_dataset("Pavithree/eli5")
|
||||
|
||||
wm_outputs = []
|
||||
|
||||
for batch_id in tqdm.tqdm(range(num_pos_batches)):
|
||||
prompts = eli5_prompts["train"]["title"][batch_id * pos_batch_size : (batch_id + 1) * pos_batch_size]
|
||||
prompts = [prompt.strip('"') for prompt in prompts]
|
||||
inputs = tokenizer(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(device)
|
||||
_, inputs_len = inputs["input_ids"].shape
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
watermarking_config=watermark_config,
|
||||
do_sample=True,
|
||||
max_length=inputs_len + max_output_len,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
wm_outputs.append(outputs[:, inputs_len:].cpu().detach())
|
||||
|
||||
del outputs, inputs, prompts
|
||||
gc.collect()
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return wm_outputs
|
||||
|
||||
|
||||
def upload_model_to_hf(model, hf_repo_name: str, private: bool = True):
|
||||
api = HfApi()
|
||||
|
||||
# Check if the repository exists
|
||||
try:
|
||||
api.repo_info(repo_id=hf_repo_name, use_auth_token=True)
|
||||
print(f"Repository '{hf_repo_name}' already exists.")
|
||||
except RepositoryNotFoundError:
|
||||
# If the repository does not exist, create it
|
||||
print(f"Repository '{hf_repo_name}' not found. Creating it...")
|
||||
create_repo(repo_id=hf_repo_name, private=private, use_auth_token=True)
|
||||
print(f"Repository '{hf_repo_name}' created successfully.")
|
||||
|
||||
# Push the model to the Hugging Face Hub
|
||||
print(f"Uploading model to Hugging Face repo '{hf_repo_name}'...")
|
||||
model.push_to_hub(repo_id=hf_repo_name, use_auth_token=True)
|
@ -1301,6 +1301,8 @@ else:
|
||||
_import_structure["generation"].extend(
|
||||
[
|
||||
"AlternatingCodebooksLogitsProcessor",
|
||||
"BayesianDetectorConfig",
|
||||
"BayesianDetectorModel",
|
||||
"BeamScorer",
|
||||
"BeamSearchScorer",
|
||||
"ClassifierFreeGuidanceLogitsProcessor",
|
||||
@ -1339,6 +1341,9 @@ else:
|
||||
"StopStringCriteria",
|
||||
"SuppressTokensAtBeginLogitsProcessor",
|
||||
"SuppressTokensLogitsProcessor",
|
||||
"SynthIDTextWatermarkDetector",
|
||||
"SynthIDTextWatermarkingConfig",
|
||||
"SynthIDTextWatermarkLogitsProcessor",
|
||||
"TemperatureLogitsWarper",
|
||||
"TopKLogitsWarper",
|
||||
"TopPLogitsWarper",
|
||||
@ -6213,6 +6218,8 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .generation import (
|
||||
AlternatingCodebooksLogitsProcessor,
|
||||
BayesianDetectorConfig,
|
||||
BayesianDetectorModel,
|
||||
BeamScorer,
|
||||
BeamSearchScorer,
|
||||
ClassifierFreeGuidanceLogitsProcessor,
|
||||
@ -6251,6 +6258,9 @@ if TYPE_CHECKING:
|
||||
StopStringCriteria,
|
||||
SuppressTokensAtBeginLogitsProcessor,
|
||||
SuppressTokensLogitsProcessor,
|
||||
SynthIDTextWatermarkDetector,
|
||||
SynthIDTextWatermarkingConfig,
|
||||
SynthIDTextWatermarkLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
|
@ -18,7 +18,13 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_utils": ["GenerationConfig", "GenerationMode", "WatermarkingConfig"],
|
||||
"configuration_utils": [
|
||||
"BaseWatermarkingConfig",
|
||||
"GenerationConfig",
|
||||
"GenerationMode",
|
||||
"SynthIDTextWatermarkingConfig",
|
||||
"WatermarkingConfig",
|
||||
],
|
||||
"streamers": ["TextIteratorStreamer", "TextStreamer"],
|
||||
}
|
||||
|
||||
@ -71,6 +77,7 @@ else:
|
||||
"SequenceBiasLogitsProcessor",
|
||||
"SuppressTokensLogitsProcessor",
|
||||
"SuppressTokensAtBeginLogitsProcessor",
|
||||
"SynthIDTextWatermarkLogitsProcessor",
|
||||
"TemperatureLogitsWarper",
|
||||
"TopKLogitsWarper",
|
||||
"TopPLogitsWarper",
|
||||
@ -110,6 +117,9 @@ else:
|
||||
_import_structure["watermarking"] = [
|
||||
"WatermarkDetector",
|
||||
"WatermarkDetectorOutput",
|
||||
"BayesianDetectorModel",
|
||||
"BayesianDetectorConfig",
|
||||
"SynthIDTextWatermarkDetector",
|
||||
]
|
||||
|
||||
try:
|
||||
@ -179,7 +189,13 @@ else:
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_utils import GenerationConfig, GenerationMode, WatermarkingConfig
|
||||
from .configuration_utils import (
|
||||
BaseWatermarkingConfig,
|
||||
GenerationConfig,
|
||||
GenerationMode,
|
||||
SynthIDTextWatermarkingConfig,
|
||||
WatermarkingConfig,
|
||||
)
|
||||
from .streamers import TextIteratorStreamer, TextStreamer
|
||||
|
||||
try:
|
||||
@ -217,6 +233,7 @@ if TYPE_CHECKING:
|
||||
SequenceBiasLogitsProcessor,
|
||||
SuppressTokensAtBeginLogitsProcessor,
|
||||
SuppressTokensLogitsProcessor,
|
||||
SynthIDTextWatermarkLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
@ -254,6 +271,9 @@ if TYPE_CHECKING:
|
||||
SampleEncoderDecoderOutput,
|
||||
)
|
||||
from .watermarking import (
|
||||
BayesianDetectorConfig,
|
||||
BayesianDetectorModel,
|
||||
SynthIDTextWatermarkDetector,
|
||||
WatermarkDetector,
|
||||
WatermarkDetectorOutput,
|
||||
)
|
||||
|
@ -18,8 +18,9 @@ import copy
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from .. import __version__
|
||||
from ..configuration_utils import PretrainedConfig
|
||||
@ -59,6 +60,7 @@ if is_torch_available():
|
||||
StaticCache,
|
||||
StaticCacheConfig,
|
||||
)
|
||||
from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
|
||||
|
||||
NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
|
||||
NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig
|
||||
@ -280,23 +282,10 @@ class GenerationConfig(PushToHubMixin):
|
||||
low_memory (`bool`, *optional*):
|
||||
Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
|
||||
Used with beam search and contrastive search.
|
||||
watermarking_config (`WatermarkingConfig` or `dict`, *optional*):
|
||||
Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green" tokens.
|
||||
If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
|
||||
See [this paper](https://arxiv.org/abs/2306.04634) for more details. Accepts the following keys:
|
||||
- greenlist_ratio (`float`):
|
||||
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
|
||||
- bias (`float`):
|
||||
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
|
||||
- hashing_key (`int`):
|
||||
Hahsing key used for watermarking. Defaults to 15485863 (the millionth prime).
|
||||
- seeding_scheme (`str`):
|
||||
Algorithm to use for watermarking. Accepts values:
|
||||
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
|
||||
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
|
||||
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
|
||||
- context_width (`int`):
|
||||
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
|
||||
watermarking_config (`BaseWatermarkingConfig` or `dict`, *optional*):
|
||||
Arguments used to watermark the model outputs by adding a small bias to a randomly selected set of "green"
|
||||
tokens. See the docs of [`SynthIDTextWatermarkingConfig`] and [`WatermarkingConfig`] for more
|
||||
details. If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
|
||||
|
||||
> Parameters that define the output variables of generate
|
||||
|
||||
@ -430,7 +419,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
watermarking_config = kwargs.pop("watermarking_config", None)
|
||||
if watermarking_config is None:
|
||||
self.watermarking_config = None
|
||||
elif isinstance(watermarking_config, WatermarkingConfig):
|
||||
elif isinstance(watermarking_config, BaseWatermarkingConfig):
|
||||
self.watermarking_config = watermarking_config
|
||||
else:
|
||||
self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config)
|
||||
@ -766,7 +755,15 @@ class GenerationConfig(PushToHubMixin):
|
||||
|
||||
# 6. check watermarking arguments
|
||||
if self.watermarking_config is not None:
|
||||
if not isinstance(self.watermarking_config, WatermarkingConfig):
|
||||
if not (
|
||||
isinstance(self.watermarking_config, WatermarkingConfig)
|
||||
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
|
||||
):
|
||||
warnings.warn(
|
||||
"`watermarking_config` as a dict is deprecated. Please construct `watermarking_config` object with "
|
||||
"`WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class.",
|
||||
FutureWarning,
|
||||
)
|
||||
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
|
||||
self.watermarking_config.validate()
|
||||
|
||||
@ -1287,52 +1284,20 @@ class GenerationConfig(PushToHubMixin):
|
||||
|
||||
|
||||
@dataclass
|
||||
class WatermarkingConfig:
|
||||
"""
|
||||
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
|
||||
See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.
|
||||
|
||||
Accepts the following keys:
|
||||
- greenlist_ratio (`float`):
|
||||
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
|
||||
- bias (`float`):
|
||||
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
|
||||
- hashing_key (`int`):
|
||||
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
|
||||
- seeding_scheme (`str`):
|
||||
Algorithm to use for watermarking. Accepts values:
|
||||
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
|
||||
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
|
||||
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
|
||||
- context_width(`int`):
|
||||
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
greenlist_ratio: Optional[float] = 0.25,
|
||||
bias: Optional[float] = 2.0,
|
||||
hashing_key: Optional[int] = 15485863,
|
||||
seeding_scheme: Optional[str] = "lefthash",
|
||||
context_width: Optional[int] = 1,
|
||||
):
|
||||
self.greenlist_ratio = greenlist_ratio
|
||||
self.bias = bias
|
||||
self.hashing_key = hashing_key
|
||||
self.seeding_scheme = seeding_scheme
|
||||
self.context_width = context_width
|
||||
class BaseWatermarkingConfig(ABC):
|
||||
"""Generic watermarking config"""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict, **kwargs):
|
||||
"""
|
||||
Constructs a WatermarkingConfig instance from a dictionary of parameters.
|
||||
Constructs a BaseWatermarkingConfig instance from a dictionary of parameters.
|
||||
|
||||
Args:
|
||||
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
|
||||
**kwargs: Additional keyword arguments to override dictionary values.
|
||||
|
||||
Returns:
|
||||
WatermarkingConfig: Instance of WatermarkingConfig constructed from the dictionary.
|
||||
BaseWatermarkingConfig: Instance of BaseWatermarkingConfig constructed from the dictionary.
|
||||
"""
|
||||
config = cls(**config_dict)
|
||||
to_remove = []
|
||||
@ -1394,6 +1359,49 @@ class WatermarkingConfig:
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
@abstractmethod
|
||||
def validate(self): ...
|
||||
|
||||
@abstractmethod
|
||||
def construct_processor(self, vocab_size): ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class WatermarkingConfig(BaseWatermarkingConfig):
|
||||
"""
|
||||
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
|
||||
See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.
|
||||
|
||||
Accepts the following keys:
|
||||
- greenlist_ratio (`float`):
|
||||
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
|
||||
- bias (`float`):
|
||||
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
|
||||
- hashing_key (`int`):
|
||||
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
|
||||
- seeding_scheme (`str`):
|
||||
Algorithm to use for watermarking. Accepts values:
|
||||
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
|
||||
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
|
||||
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
|
||||
- context_width (`int`):
|
||||
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
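
A minimal usage sketch follows; the checkpoint name is only an illustrative choice:

Examples:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, WatermarkingConfig

>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

>>> # Bias the pseudo-randomly selected "green" tokens during generation
>>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")

>>> tokenized_prompts = tokenizer(["your prompts here"], return_tensors="pt")
>>> output_sequences = model.generate(
...     **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True, max_new_tokens=20
... )
>>> watermarked_text = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
```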
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
greenlist_ratio: Optional[float] = 0.25,
|
||||
bias: Optional[float] = 2.0,
|
||||
hashing_key: Optional[int] = 15485863,
|
||||
seeding_scheme: Optional[str] = "lefthash",
|
||||
context_width: Optional[int] = 1,
|
||||
):
|
||||
self.greenlist_ratio = greenlist_ratio
|
||||
self.bias = bias
|
||||
self.hashing_key = hashing_key
|
||||
self.seeding_scheme = seeding_scheme
|
||||
self.context_width = context_width
|
||||
|
||||
def validate(self):
|
||||
watermark_missing_arg_msg = (
|
||||
"Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
|
||||
@ -1423,3 +1431,104 @@ class WatermarkingConfig:
|
||||
found_value=self.context_width,
|
||||
),
|
||||
)
|
||||
|
||||
def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
|
||||
return WatermarkLogitsProcessor(
|
||||
vocab_size=vocab_size,
|
||||
device=device,
|
||||
greenlist_ratio=self.greenlist_ratio,
|
||||
bias=self.bias,
|
||||
hashing_key=self.hashing_key,
|
||||
seeding_scheme=self.seeding_scheme,
|
||||
context_width=self.context_width,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig):
|
||||
"""
|
||||
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
|
||||
See [this paper](https://www.nature.com/articles/s41586-024-08025-4) for more details on the arguments.
|
||||
|
||||
Args:
|
||||
ngram_len (`int`):
|
||||
Ngram length.
|
||||
keys (`List[int]`):
|
||||
A sequence of watermarking keys, one for each depth.
|
||||
context_history_size (`int`, *optional*, defaults to 1024):
|
||||
Size of the tensor to keep track of seen contexts.
|
||||
sampling_table_seed (`int`, *optional*, defaults to 0):
|
||||
Random seed to generate the sampling table.
|
||||
sampling_table_size (`int`, *optional*, defaults to 65536):
|
||||
Size of the sampling table.
|
||||
skip_first_ngram_calls (`bool`, *optional*, defaults to `False`):
|
||||
Whether to skip first ngram calls.
|
||||
debug_mode (`bool`, *optional*, defaults to `False`):
Logits are replaced with a uniform distribution before the watermarking modification is applied. This is to
test the implementation.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
|
||||
>>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b-it')
|
||||
|
||||
>>> # SynthID Text configuration
|
||||
>>> watermarking_config = SynthIDTextWatermarkingConfig(
|
||||
... keys=[654, 400, 836, 123, 340, 443, 597, 160, 57],
|
||||
... ngram_len=5,
|
||||
... )
|
||||
|
||||
>>> # Generation with watermarking
|
||||
>>> tokenized_prompts = tokenizer(["your prompts here"])
|
||||
>>> output_sequences = model.generate(
|
||||
... **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True,
|
||||
... )
|
||||
>>> watermarked_text = tokenizer.batch_decode(output_sequences)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ngram_len: int,
|
||||
keys: List[int],
|
||||
context_history_size: int = 1024,
|
||||
sampling_table_seed: int = 0,
|
||||
sampling_table_size: int = 2**16,
|
||||
skip_first_ngram_calls: bool = False,
|
||||
debug_mode: bool = False,
|
||||
):
|
||||
self.ngram_len = ngram_len
|
||||
self.keys = keys
|
||||
self.sampling_table_size = sampling_table_size
|
||||
self.sampling_table_seed = sampling_table_seed
|
||||
self.context_history_size = context_history_size
|
||||
self.skip_first_ngram_calls = skip_first_ngram_calls
|
||||
self.debug_mode = debug_mode
|
||||
|
||||
def validate(self):
|
||||
watermark_missing_arg_msg = (
|
||||
"Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
|
||||
"but found {found_value}"
|
||||
)
|
||||
if self.sampling_table_size > 2**24:
|
||||
raise ValueError(
|
||||
watermark_missing_arg_msg.format(
|
||||
key="sampling_table_size",
|
||||
correct_value="< 2**24",
|
||||
found_value=self.sampling_table_size,
|
||||
),
|
||||
)
|
||||
|
||||
def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
|
||||
return SynthIDTextWatermarkLogitsProcessor(
|
||||
ngram_len=self.ngram_len,
|
||||
keys=self.keys,
|
||||
sampling_table_size=self.sampling_table_size,
|
||||
sampling_table_seed=self.sampling_table_seed,
|
||||
context_history_size=self.context_history_size,
|
||||
device=device,
|
||||
skip_first_ngram_calls=self.skip_first_ngram_calls,
|
||||
debug_mode=self.debug_mode,
|
||||
)
|
||||
|
@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team
|
||||
# Copyright 2024 The HuggingFace Inc. team and Google DeepMind.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -2460,6 +2460,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
|
||||
final_greenlist.append(greedy_predictions[i])
|
||||
return torch.tensor(final_greenlist, device=input_seq.device)
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if input_ids.shape[-1] < self.context_width:
|
||||
logger.warning(
|
||||
@ -2477,3 +2478,478 @@ class WatermarkLogitsProcessor(LogitsProcessor):
|
||||
scores_processed[b_idx, greenlist_ids] = scores_processed[b_idx, greenlist_ids] + self.bias
|
||||
|
||||
return scores_processed
|
||||
|
||||
|
||||
class SynthIDTextWatermarkState:
|
||||
"""SynthID watermarking state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
ngram_len: int,
|
||||
context_history_size: int,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Initializes the state.
|
||||
|
||||
Args:
|
||||
batch_size (`int`): Batch size.
|
||||
ngram_len (`int`): Ngram length.
|
||||
context_history_size (`int`): Size of the tensor to keep track of seen contexts.
|
||||
device (`torch.device`): Device to use.
|
||||
"""
|
||||
self.context = torch.zeros(
|
||||
(batch_size, ngram_len - 1),
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
self.context_history = torch.zeros(
|
||||
(batch_size, context_history_size),
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
self.num_calls = 0
|
||||
|
||||
|
||||
class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
Logits processor that implements watermarking techniques for text generation models.
|
||||
This class facilitates the application of SynthID text watermarking, a method for embedding imperceptible signals
|
||||
into generated text to aid in detecting synthetic content. It operates by subtly manipulating the probabilities of
|
||||
token selection during text generation in a manner that can be reliably recovered later for verification.
|
||||
|
||||
Key Features:
|
||||
* **State Management:** Maintains internal state to track token sequences and generate watermarking keys
|
||||
dynamically.
|
||||
|
||||
* **Key Generation:** Computes hashes based on token sequences and watermarking parameters to create unique keys
|
||||
for each position.
|
||||
|
||||
* **G-Value Sampling:** Employs a pre-computed sampling table to sample watermarking values (g-values) based on
|
||||
the generated keys.
|
||||
|
||||
* **Score Adjustment:** Applies calculated g-values to modify token probabilities during generation, embedding the
|
||||
watermark.
|
||||
|
||||
* **Context Repetition Handling:** Incorporates logic to avoid watermarking tokens in repeated contexts,
|
||||
preserving naturalness.
|
||||
|
||||
* **EOS Token Masking:** Supports masking end-of-sentence tokens to prevent their inclusion in watermarking
|
||||
calculations.
|
||||
|
||||
* **Utility Functions:** Provides functions to compute g-values directly, check for context repetition, create
|
||||
EOS token masks, and estimate expected mean g-values.
|
||||
|
||||
Refer to the paper at https://www.nature.com/articles/s41586-024-08025-4 for more details.
|
||||
|
||||
Args:
|
||||
ngram_len (`int`):
|
||||
Ngram length.
|
||||
keys (`List[int]`):
|
||||
A sequence of watermarking keys, one for each depth.
|
||||
sampling_table_size (`int`):
|
||||
Size of the sampling table.
|
||||
sampling_table_seed (`int`):
|
||||
Random seed to generate the sampling table.
|
||||
context_history_size (`int`):
|
||||
Size of the tensor to keep track of seen contexts.
|
||||
device (`torch.device`):
|
||||
Device to use.
|
||||
skip_first_ngram_calls (`bool`, *optional*, defaults to `False`):
|
||||
Whether to skip first ngram calls.
|
||||
debug_mode (`bool`, *optional*, defaults to `False`):
Logits are replaced with a uniform distribution before the watermarking modification is applied, to make
the implementation easier to test.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
|
||||
>>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b-it')
|
||||
|
||||
>>> # SynthID Text configuration
|
||||
>>> watermarking_config = SynthIDTextWatermarkingConfig(
|
||||
... keys=[654, 400, 836, 123, 340, 443, 597, 160, 57],
|
||||
... ngram_len=5,
|
||||
... )
|
||||
|
||||
>>> # Generation with watermarking
|
||||
>>> tokenized_prompts = tokenizer(["your prompts here"])
|
||||
>>> output_sequences = model.generate(
|
||||
... **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True,
|
||||
... )
|
||||
>>> watermarked_text = tokenizer.batch_decode(output_sequences)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ngram_len: int,
|
||||
keys: List[int],
|
||||
sampling_table_size: int,
|
||||
sampling_table_seed: int,
|
||||
context_history_size: int,
|
||||
device: torch.device,
|
||||
skip_first_ngram_calls: bool = False,
|
||||
debug_mode: bool = False,
|
||||
):
|
||||
self.ngram_len = ngram_len
|
||||
self.keys = torch.tensor(keys, device=device)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(sampling_table_seed)
|
||||
# A random sampling table is pre-computed and modulo table size is applied to map from a hash of ngram keys to
|
||||
# g values; this is similar to the hashtable implementation used in
|
||||
# https://github.com/facebookresearch/three_bricks. We note that the hashing employed in this repository is
|
||||
# different from that used to watermark the Gemini App, and hence the detectors trained based on the
|
||||
# hashing in this repository will not transfer to text generated by the Gemini App.
|
||||
self.sampling_table = torch.randint(
|
||||
low=0,
|
||||
high=2,
|
||||
size=(sampling_table_size,),
|
||||
generator=generator,
|
||||
device=device,
|
||||
)
|
||||
self.context_history_size = context_history_size
|
||||
self.device = device
|
||||
self.state = None
|
||||
self.skip_first_ngram_calls = skip_first_ngram_calls
|
||||
self.debug_mode = debug_mode
|
||||
|
||||
def _init_state(self, batch_size: int):
|
||||
"""Initializes the state."""
|
||||
self.state = SynthIDTextWatermarkState(
|
||||
batch_size=batch_size,
|
||||
ngram_len=self.ngram_len,
|
||||
context_history_size=self.context_history_size,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def update_scores(self, scores: torch.FloatTensor, g_values: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""Updates scores using the g values.
|
||||
|
||||
We assume that the scores are in the log space.
|
||||
Args:
|
||||
scores (`torch.FloatTensor`): Scores (batch_size, vocab_size).
|
||||
g_values (`torch.FloatTensor`): G values (batch_size, vocab_size, depth).
|
||||
|
||||
Returns:
|
||||
Updated scores (batch_size, vocab_size).
|
||||
"""
|
||||
_, _, depth = g_values.shape
|
||||
|
||||
probs = torch.softmax(scores, dim=1)
|
||||
|
||||
for i in range(depth):
|
||||
g_values_at_depth = g_values[:, :, i]
|
||||
g_mass_at_depth = (g_values_at_depth * probs).sum(axis=1, keepdims=True)
|
||||
probs = probs * (1 + g_values_at_depth - g_mass_at_depth)
|
||||
|
||||
log_probs = torch.log(probs)
|
||||
log_probs = torch.where(torch.isfinite(log_probs), log_probs, torch.finfo(log_probs.dtype).min)
|
||||
return log_probs
|
||||
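As a quick illustration of the update above, here is a standalone sketch (not part of the diff) that applies the same multiplicative reweighting to a toy 4-token distribution with a single watermarking layer; all numbers are made up.

```python
import torch

probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])   # softmax of the original scores
g = torch.tensor([[1.0, 0.0, 1.0, 0.0]])       # g-values for one layer

g_mass = (g * probs).sum(dim=1, keepdim=True)  # expected g-value under probs (0.4 here)
new_probs = probs * (1 + g - g_mass)           # the same multiplicative update as above

print(new_probs.sum(dim=1))        # 1.0 -> still a valid distribution
print((g * new_probs).sum(dim=1))  # 0.64 -> mass shifted toward tokens with g == 1
```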
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
self._check_input_ids_shape(input_ids)
|
||||
batch_size, vocab_size = scores.shape
|
||||
|
||||
if self.debug_mode:
|
||||
scores = torch.ones_like(scores)
|
||||
|
||||
# Currently indices is just an arange, used to compute watermarking on the dense logits.
|
||||
all_indices = torch.stack([torch.arange(vocab_size, device=self.device) for _ in range(batch_size)])
|
||||
|
||||
if self.state is None:
|
||||
# Initialize watermarking state if it does not exist.
|
||||
self._init_state(batch_size)
|
||||
else:
|
||||
# Append last input id (which is the input id added in last call) to the
|
||||
# previous context so we have the context to be used for current
|
||||
# watermarking.
|
||||
self.state.context = torch.concat(
|
||||
(self.state.context, input_ids[:, -1:]),
|
||||
dim=1,
|
||||
)
|
||||
self.state.context = self.state.context[:, 1:]
|
||||
|
||||
if self.state is None:
|
||||
raise ValueError("self.state can't be None! Call `self._init_state` to initialize the state.")
|
||||
|
||||
self.state.num_calls += 1
|
||||
|
||||
# Don't watermark the first ngram_len - 1 tokens if set.
|
||||
if self.skip_first_ngram_calls and self.state.num_calls < self.ngram_len:
|
||||
return scores
|
||||
|
||||
# 2. Generate random keys for each ngram key combination.
|
||||
ngram_keys, hash_result_with_just_context = self._compute_keys(self.state.context, all_indices)
|
||||
# ngram_keys shape [batch_size, top_k, depth]
|
||||
|
||||
# 3. Sample g values.
|
||||
g_values = self.sample_g_values(ngram_keys)
|
||||
# g_values shape [batch_size, top_k, depth]
|
||||
|
||||
# 4. Modify scores.
|
||||
updated_scores = self.update_scores(scores, g_values)
|
||||
# updated scores shape [batch_size, top_k]
|
||||
|
||||
# 5. Check if the current watermarking context was previously used, if yes skip watermarking.
|
||||
hash_result_with_just_context = hash_result_with_just_context[:, None]
|
||||
is_repeated_context = (self.state.context_history == hash_result_with_just_context).any(
|
||||
dim=1,
|
||||
keepdim=True,
|
||||
)
|
||||
self.state.context_history = torch.concat(
|
||||
(hash_result_with_just_context, self.state.context_history),
|
||||
dim=1,
|
||||
)[:, :-1]
|
||||
|
||||
updated_watermarked_scores = torch.where(
|
||||
is_repeated_context,
|
||||
input=scores,
|
||||
other=updated_scores,
|
||||
)
|
||||
return updated_watermarked_scores
|
||||
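For readers who want to poke at the processor outside of `generate`, the snippet below is a standalone sketch (not part of the diff) that calls it directly on random ids and logits; the key values and sizes are arbitrary placeholders.

```python
import torch

from transformers import SynthIDTextWatermarkLogitsProcessor

processor = SynthIDTextWatermarkLogitsProcessor(
    ngram_len=5,
    keys=[654, 400, 836, 123, 340, 443, 597, 160, 57],
    sampling_table_size=2**16,
    sampling_table_seed=0,
    context_history_size=1024,
    device="cpu",
)

input_ids = torch.randint(0, 1000, (2, 6))         # pretend prompt, 1000-token vocab
scores = torch.randn(2, 1000)                      # raw next-token logits from a model
watermarked_scores = processor(input_ids, scores)  # logits nudged toward high-g tokens
print(watermarked_scores.shape)                    # torch.Size([2, 1000])
```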
|
||||
def accumulate_hash(
|
||||
self,
|
||||
current_hash: torch.LongTensor,
|
||||
data: torch.LongTensor,
|
||||
multiplier: int = 6364136223846793005,
|
||||
increment: int = 1,
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Accumulate hash of data on current hash.
|
||||
|
||||
Method uses adapted linear congruential generator with newlib/musl parameters.
|
||||
|
||||
This function has following property -
|
||||
f(x, data[T]) = f(f(x, data[:T - 1]), data[T])
|
||||
|
||||
This function expects current_hash.shape and data.shape[:-1] to
|
||||
match/broadcastable.
|
||||
|
||||
Args:
|
||||
current_hash (`torch.LongTensor`):
|
||||
(shape,)
|
||||
data (`torch.LongTensor`):
|
||||
(shape, tensor_len)
|
||||
multiplier (`int`, *optional*, defaults to 6364136223846793005):
multiplier of the linear congruential generator
increment (`int`, *optional*, defaults to 1):
increment of the linear congruential generator
|
||||
|
||||
Returns:
|
||||
updated hash (shape,)
|
||||
"""
|
||||
for i in range(data.shape[-1]):
|
||||
current_hash = torch.add(current_hash, data[..., i])
|
||||
current_hash = torch.mul(current_hash, multiplier)
|
||||
current_hash = torch.add(current_hash, increment)
|
||||
return current_hash
|
||||
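As a small standalone check (not part of the diff) of the prefix property stated in the docstring, the sketch below re-implements the accumulation loop and extends the hash one element at a time.

```python
import torch

def accumulate_hash(current_hash, data, multiplier=6364136223846793005, increment=1):
    # Equivalent re-implementation of the LCG-style accumulation above.
    for i in range(data.shape[-1]):
        current_hash = (current_hash + data[..., i]) * multiplier + increment
    return current_hash

x = torch.ones(2, dtype=torch.int64)
data = torch.randint(0, 1000, (2, 5), dtype=torch.int64)

full = accumulate_hash(x, data)
incremental = accumulate_hash(accumulate_hash(x, data[..., :-1]), data[..., -1:])
print(torch.equal(full, incremental))  # True -> the hash can be extended token by token
```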
|
||||
def compute_ngram_keys(self, ngrams: torch.LongTensor) -> torch.LongTensor:
|
||||
"""Computes random keys for each ngram and depth.
|
||||
|
||||
Args:
|
||||
ngrams (`torch.LongTensor`):
|
||||
Ngrams (batch_size, num_ngrams, ngram_len).
|
||||
|
||||
Returns:
|
||||
ngram keys (batch_size, num_ngrams, depth).
|
||||
"""
|
||||
if len(ngrams.shape) != 3:
|
||||
raise ValueError(
|
||||
"Ngrams should be of shape (batch_size, num_ngrams, ngram_len), but" f" is {ngrams.shape}"
|
||||
)
|
||||
if ngrams.shape[2] != self.ngram_len:
|
||||
raise ValueError(
|
||||
"Ngrams should be of shape (batch_size, num_ngrams, ngram_len),"
|
||||
f" where ngram_len is {self.ngram_len}, but is {ngrams.shape}"
|
||||
)
|
||||
batch_size, _, _ = ngrams.shape
|
||||
|
||||
hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
|
||||
# hash_result shape [batch_size,]
|
||||
# ngrams shape [batch_size, num_ngrams, ngram_len]
|
||||
hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 1), out_dims=1)(hash_result, ngrams)
|
||||
# hash_result shape [batch_size, num_ngrams]
|
||||
|
||||
keys = self.keys[None, None, :, None]
|
||||
# hash_result shape [batch_size, num_ngrams]
|
||||
# keys shape [1, 1, depth, 1]
|
||||
hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 2), out_dims=2)(hash_result, keys)
|
||||
# hash_result shape [batch_size, num_ngrams, depth]
|
||||
|
||||
return hash_result
|
||||
|
||||
def _compute_keys(
|
||||
self, n_minus_1_grams: torch.LongTensor, indices: torch.LongTensor
|
||||
) -> Tuple[torch.LongTensor, torch.LongTensor]:
|
||||
"""Computes random keys for each ngram and depth.
|
||||
|
||||
Args:
|
||||
n_minus_1_grams (`torch.LongTensor`):
|
||||
Ngrams (batch_size, ngram_len - 1).
|
||||
indices (`torch.LongTensor`):
|
||||
indices of the continuations (batch_size, num_indices)
|
||||
|
||||
Returns:
|
||||
Ngram keys (batch_size, num_indices, depth).
|
||||
"""
|
||||
batch_size, _ = n_minus_1_grams.shape
|
||||
|
||||
hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
|
||||
# First hash n_minus_1 gram, for each batch entry we have a single
|
||||
# n_minus_1 gram context.
|
||||
# hash_result shape [batch_size]
|
||||
# n_minus_1_gram shape [batch_size, ngram_len - 1]
|
||||
hash_result_with_just_context = self.accumulate_hash(hash_result, n_minus_1_grams)
|
||||
# hash_result shape [batch_size,]
|
||||
# Indices is of shape [batch_size, num_indices], so we make it
|
||||
# [batch_size, num_indices, 1] so we can vmap over num_indices dim.
|
||||
hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 1), out_dims=1)(
|
||||
hash_result_with_just_context, indices[:, :, None]
|
||||
)
|
||||
# hash_result shape [batch_size, num_indices]
|
||||
# Basically we have a hash for each batch entry and each indices
|
||||
# Now we add watermarking keys to this hash.
|
||||
# keys are of shape [depth,]
|
||||
# We add batch, num_indices and data dimension to this making it
|
||||
# [1, 1, depth, 1].
|
||||
# So we can vmap over the depth dimension for compute_hash
|
||||
keys = self.keys[None, None, :, None]
|
||||
hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 2), out_dims=2)(hash_result, keys)
|
||||
# hash_result shape should be [batch_size, num_indices, depth]
|
||||
return hash_result, hash_result_with_just_context
|
||||
|
||||
def sample_g_values(self, ngram_keys: torch.LongTensor) -> torch.LongTensor:
|
||||
"""
|
||||
Samples g values from Bernoulli distribution.
|
||||
|
||||
It is not possible to pass random keys in a vectorized way in torch. Instead
we pre-compute a random sampling table, and apply modulo of the table size to
map from ngram keys (int64) to g values.
|
||||
|
||||
Args:
|
||||
ngram_keys (`torch.LongTensor`):
|
||||
Random keys (batch_size, num_ngrams, depth).
|
||||
|
||||
Returns:
|
||||
G values (batch_size, num_ngrams, depth).
|
||||
"""
|
||||
(sampling_table_size,) = self.sampling_table.shape
|
||||
sampling_table = self.sampling_table.reshape((1, 1, sampling_table_size))
|
||||
ngram_keys = ngram_keys % sampling_table_size
|
||||
return torch.take_along_dim(sampling_table, indices=ngram_keys, dim=2)
|
||||
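The lookup above amounts to indexing a fixed Bernoulli(0.5) table with the hashed keys modulo the table size; a minimal sketch with arbitrary sizes (not part of the diff):

```python
import torch

table = torch.randint(0, 2, (2**16,))      # pre-computed Bernoulli(0.5) sampling table
keys = torch.randint(0, 2**62, (3, 4, 2))  # hashed ngram keys: (batch, num_ngrams, depth)

g_values = torch.take_along_dim(table.view(1, 1, -1), keys % table.shape[0], dim=2)
print(g_values.shape)  # torch.Size([3, 4, 2]), entries in {0, 1}
```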
|
||||
def _check_input_ids_shape(self, input_ids: torch.LongTensor):
|
||||
"""Checks the shape of input ids."""
|
||||
if len(input_ids.shape) != 2:
|
||||
raise ValueError("Input ids should be of shape (batch_size, input_len), but is" f" {input_ids.shape}")
|
||||
|
||||
def compute_g_values(self, input_ids: torch.LongTensor) -> torch.LongTensor:
|
||||
"""
|
||||
Computes g values for each ngram from the given sequence of tokens.
|
||||
|
||||
Args:
|
||||
input_ids (`torch.LongTensor`):
|
||||
Input token ids (batch_size, input_len).
|
||||
|
||||
Returns:
|
||||
G values (batch_size, input_len - (ngram_len - 1), depth).
|
||||
"""
|
||||
self._check_input_ids_shape(input_ids)
|
||||
ngrams = input_ids.unfold(dimension=1, size=self.ngram_len, step=1)
|
||||
ngram_keys = self.compute_ngram_keys(ngrams)
|
||||
return self.sample_g_values(ngram_keys)
|
||||
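To see what the `unfold` call produces, a tiny standalone example (not part of the diff) with `ngram_len=5`:

```python
import torch

ids = torch.arange(8).view(1, 8)                   # one sequence of 8 token ids
ngrams = ids.unfold(dimension=1, size=5, step=1)   # sliding windows of length 5
print(ngrams.shape)   # torch.Size([1, 4, 5])
print(ngrams[0, 0])   # tensor([0, 1, 2, 3, 4])
```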
|
||||
def compute_context_repetition_mask(self, input_ids: torch.LongTensor) -> torch.LongTensor:
|
||||
"""
|
||||
Computes repetition mask.
|
||||
|
||||
0 and 1 stand for repeated and not repeated context n-1 grams respectively.
|
||||
|
||||
Args:
|
||||
input_ids (`torch.LongTensor`):
|
||||
Input token ids (batch_size, input_len).
|
||||
|
||||
Returns:
|
||||
Repetitions mask (batch_size, input_len - (ngram_len - 1)).
|
||||
"""
|
||||
self._check_input_ids_shape(input_ids)
|
||||
batch_size, _ = input_ids.shape
|
||||
state = SynthIDTextWatermarkState(
|
||||
batch_size=batch_size,
|
||||
ngram_len=self.ngram_len,
|
||||
context_history_size=self.context_history_size,
|
||||
device=self.device,
|
||||
)
|
||||
contexts = input_ids[:, :-1].unfold(
|
||||
dimension=1,
|
||||
size=self.ngram_len - 1,
|
||||
step=1,
|
||||
)
|
||||
_, num_contexts, _ = contexts.shape
|
||||
|
||||
are_repeated_contexts = []
|
||||
for i in range(num_contexts):
|
||||
context = contexts[:, i, :]
|
||||
hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
|
||||
context_hash = self.accumulate_hash(hash_result, context)[:, None]
|
||||
is_repeated_context = (state.context_history == context_hash).any(
|
||||
dim=1,
|
||||
keepdim=True,
|
||||
)
|
||||
are_repeated_contexts.append(is_repeated_context)
|
||||
state.context_history = torch.concat(
|
||||
(context_hash, state.context_history),
|
||||
dim=1,
|
||||
)[:, :-1]
|
||||
are_repeated_contexts = torch.concat(are_repeated_contexts, dim=1)
|
||||
|
||||
return torch.logical_not(are_repeated_contexts)
|
||||
|
||||
def compute_eos_token_mask(self, input_ids: torch.LongTensor, eos_token_id: int) -> torch.LongTensor:
|
||||
"""
|
||||
Computes the EOS token mask.

1 stands for tokens before the first EOS token in the sequence, and 0 from the first EOS token onwards.
|
||||
|
||||
Args:
|
||||
input_ids (`torch.LongTensor`):
|
||||
Input token ids (batch_size, input_len).
|
||||
eos_token_id (`int`):
|
||||
EOS token ID.
|
||||
|
||||
Returns:
|
||||
EOS token mask (batch_size, input_len).
|
||||
"""
|
||||
self._check_input_ids_shape(input_ids)
|
||||
noneos_masks = []
|
||||
all_eos_equated = input_ids == eos_token_id
|
||||
for eos_equated in all_eos_equated:
|
||||
nonzero_idx = torch.nonzero(eos_equated)
|
||||
noneos_mask = torch.ones_like(eos_equated)
|
||||
if nonzero_idx.shape[0] != 0:
|
||||
noneos_mask[nonzero_idx[0][0] :] = 0
|
||||
noneos_masks.append(noneos_mask)
|
||||
return torch.stack(noneos_masks, dim=0)
|
||||
|
||||
def expected_mean_g_value(self, vocab_size: int, coinflip_prob: float = 0.5) -> float:
|
||||
"""
|
||||
Compute expected mean g-value after watermarking, assuming uniform LM dist.
|
||||
|
||||
This is the theoretical expected value for single-layer watermarking.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`):
|
||||
The size of the vocabulary.
|
||||
coinflip_prob (`float`, *optional*, defaults to 0.5):
|
||||
Probability of 1 in boolean prf.
|
||||
|
||||
Returns:
|
||||
The expected mean g-value for watermarked text.
|
||||
"""
|
||||
return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size))
|
||||
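For example (illustration only, not part of the diff), with a 10-token vocabulary and the default coinflip probability the expected mean g-value is 0.725, compared to 0.5 for unwatermarked text:

```python
vocab_size, p = 10, 0.5
expected = p + p * (1 - p) * (1 - 1 / vocab_size)
print(expected)  # 0.725
```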
|
@ -92,7 +92,6 @@ from .logits_process import (
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WatermarkLogitsProcessor,
|
||||
)
|
||||
from .stopping_criteria import (
|
||||
ConfidenceCriteria,
|
||||
@ -1011,15 +1010,7 @@ class GenerationMixin:
|
||||
)
|
||||
if generation_config.watermarking_config is not None:
|
||||
processors.append(
|
||||
WatermarkLogitsProcessor(
|
||||
vocab_size=self.config.vocab_size,
|
||||
device=device,
|
||||
greenlist_ratio=generation_config.watermarking_config.greenlist_ratio,
|
||||
bias=generation_config.watermarking_config.bias,
|
||||
hashing_key=generation_config.watermarking_config.hashing_key,
|
||||
seeding_scheme=generation_config.watermarking_config.seeding_scheme,
|
||||
context_width=generation_config.watermarking_config.context_width,
|
||||
)
|
||||
generation_config.watermarking_config.construct_processor(self.config.vocab_size, device)
|
||||
)
|
||||
|
||||
# TODO (joao): find a strategy to specify the order of the processors
|
||||
|
@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team
|
||||
# Copyright 2024 The HuggingFace Inc. team and Google DeepMind.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -16,19 +16,22 @@
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import BCELoss
|
||||
|
||||
from ..configuration_utils import PretrainedConfig
|
||||
from ..utils import is_torch_available, logging
|
||||
from .configuration_utils import WatermarkingConfig
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
from ..utils import ModelOutput, is_torch_available, logging
|
||||
from .configuration_utils import PretrainedConfig, WatermarkingConfig
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from .logits_process import WatermarkLogitsProcessor
|
||||
from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -237,3 +240,310 @@ class WatermarkDetector:
|
||||
confidence=confidence,
|
||||
)
|
||||
return prediction
|
||||
|
||||
|
||||
class BayesianDetectorConfig(PretrainedConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`BayesianDetectorModel`]. It is used to
|
||||
instantiate a Bayesian Detector model according to the specified arguments.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
watermarking_depth (`int`, *optional*):
|
||||
The number of tournament layers.
|
||||
base_rate (`float`, *optional*, defaults to 0.5):
|
||||
Prior probability P(w) that a text is watermarked.
|
||||
"""
|
||||
|
||||
def __init__(self, watermarking_depth: int = None, base_rate: float = 0.5, **kwargs):
|
||||
self.watermarking_depth = watermarking_depth
|
||||
self.base_rate = base_rate
|
||||
# These can be set later to store information about this detector.
|
||||
self.model_name = None
|
||||
self.watermarking_config = None
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_detector_information(self, model_name, watermarking_config):
|
||||
self.model_name = model_name
|
||||
self.watermarking_config = watermarking_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class BayesianWatermarkDetectorModelOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of models predicting if the text is watermarked.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Watermark detection loss (binary cross-entropy plus a penalty on the detector weights).
posterior_probabilities (`torch.FloatTensor` of shape `(1,)`):
Posterior probabilities that the texts are watermarked.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
posterior_probabilities: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class BayesianDetectorWatermarkedLikelihood(nn.Module):
|
||||
"""Watermarked likelihood model for binary-valued g-values.
|
||||
|
||||
This takes in g-values and returns p(g_values|watermarked).
|
||||
"""
|
||||
|
||||
def __init__(self, watermarking_depth: int):
|
||||
"""Initializes the model parameters."""
|
||||
super().__init__()
|
||||
self.watermarking_depth = watermarking_depth
|
||||
self.beta = torch.nn.Parameter(-2.5 + 0.001 * torch.randn(1, 1, watermarking_depth))
|
||||
self.delta = torch.nn.Parameter(0.001 * torch.randn(1, 1, self.watermarking_depth, watermarking_depth))
|
||||
|
||||
def _compute_latents(self, g_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Computes the unique token probability distribution given g-values.
|
||||
|
||||
Args:
|
||||
g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
|
||||
PRF values.
|
||||
|
||||
Returns:
|
||||
p_one_unique_token and p_two_unique_tokens, both of shape
|
||||
[batch_size, seq_len, watermarking_depth]. p_one_unique_token[i,t,l]
|
||||
gives the probability of there being one unique token in a tournament
|
||||
match on layer l, on timestep t, for batch item i.
|
||||
p_one_unique_token[i,t,l] + p_two_unique_token[i,t,l] = 1.
|
||||
"""
|
||||
# Tile g-values to produce feature vectors for predicting the latents
|
||||
# for each layer in the tournament; our model for the latents psi is a
|
||||
# logistic regression model psi = sigmoid(delta * x + beta).
|
||||
|
||||
# [batch_size, seq_len, watermarking_depth, watermarking_depth]
|
||||
x = torch.repeat_interleave(torch.unsqueeze(g_values, dim=-2), self.watermarking_depth, axis=-2)
|
||||
|
||||
# mask all elements above -1 diagonal for autoregressive factorization
|
||||
x = torch.tril(x, diagonal=-1)
|
||||
|
||||
# [batch_size, seq_len, watermarking_depth]
|
||||
# (i, j, k, l) x (i, j, k, l) -> (i, j, k) einsum equivalent
|
||||
logits = (self.delta[..., None, :] @ x.type(self.delta.dtype)[..., None]).squeeze() + self.beta
|
||||
|
||||
p_two_unique_tokens = torch.sigmoid(logits)
|
||||
p_one_unique_token = 1 - p_two_unique_tokens
|
||||
return p_one_unique_token, p_two_unique_tokens
|
||||
|
||||
def forward(self, g_values: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes the likelihoods P(g_values|watermarked).
|
||||
|
||||
Args:
|
||||
g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
|
||||
g-values (values 0 or 1)
|
||||
|
||||
Returns:
|
||||
p(g_values|watermarked) of shape [batch_size, seq_len, watermarking_depth].
|
||||
"""
|
||||
p_one_unique_token, p_two_unique_tokens = self._compute_latents(g_values)
|
||||
|
||||
# P(g_tl | watermarked) is equal to
|
||||
# 0.5 * [ (g_tl+0.5) * p_two_unique_tokens + p_one_unique_token].
|
||||
return 0.5 * ((g_values + 0.5) * p_two_unique_tokens + p_one_unique_token)
|
||||
|
||||
|
||||
class BayesianDetectorModel(PreTrainedModel):
|
||||
r"""
|
||||
Bayesian classifier for watermark detection.
|
||||
|
||||
This detector uses Bayes' rule to compute a watermarking score, which is the sigmoid of the log of ratio of the
|
||||
posterior probabilities P(watermarked|g_values) and P(unwatermarked|g_values). Please see the section on
|
||||
BayesianScore in the paper for further details.
|
||||
Paper URL: https://www.nature.com/articles/s41586-024-08025-4
|
||||
|
||||
Note that this detector only works with non-distortionary Tournament-based watermarking using the Bernoulli(0.5)
|
||||
g-value distribution.
|
||||
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`BayesianDetectorConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
config_class = BayesianDetectorConfig
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.watermarking_depth = config.watermarking_depth
|
||||
self.base_rate = config.base_rate
|
||||
self.likelihood_model_watermarked = BayesianDetectorWatermarkedLikelihood(
|
||||
watermarking_depth=self.watermarking_depth
|
||||
)
|
||||
self.prior = torch.nn.Parameter(torch.tensor([self.base_rate]))
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, nn.Parameter):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
|
||||
def _compute_posterior(
|
||||
self,
|
||||
likelihoods_watermarked: torch.Tensor,
|
||||
likelihoods_unwatermarked: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
prior: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute posterior P(w|g) given likelihoods, mask and prior.
|
||||
|
||||
Args:
|
||||
likelihoods_watermarked (`torch.Tensor` of shape `(batch, length, depth)`):
|
||||
Likelihoods P(g_values|watermarked) of g-values under watermarked model.
|
||||
likelihoods_unwatermarked (`torch.Tensor` of shape `(batch, length, depth)`):
|
||||
Likelihoods P(g_values|unwatermarked) of g-values under unwatermarked model.
|
||||
mask (`torch.Tensor` of shape `(batch, length)`):
|
||||
A binary array indicating which g-values should be used. g-values with mask value 0 are discarded.
|
||||
prior (`float`):
|
||||
the prior probability P(w) that the text is watermarked.
|
||||
|
||||
Returns:
|
||||
Posterior probability P(watermarked|g_values), shape [batch].
|
||||
"""
|
||||
mask = torch.unsqueeze(mask, dim=-1)
|
||||
prior = torch.clamp(prior, min=1e-5, max=1 - 1e-5)
|
||||
log_likelihoods_watermarked = torch.log(torch.clamp(likelihoods_watermarked, min=1e-30, max=float("inf")))
|
||||
log_likelihoods_unwatermarked = torch.log(torch.clamp(likelihoods_unwatermarked, min=1e-30, max=float("inf")))
|
||||
log_odds = log_likelihoods_watermarked - log_likelihoods_unwatermarked
|
||||
|
||||
# Sum relative surprisals (log odds) across all token positions and layers.
|
||||
relative_surprisal_likelihood = torch.einsum("i...->i", log_odds * mask)
|
||||
|
||||
# Compute the relative surprisal prior
|
||||
relative_surprisal_prior = torch.log(prior) - torch.log(1 - prior)
|
||||
|
||||
# Combine prior and likelihood.
|
||||
# [batch_size]
|
||||
relative_surprisal = relative_surprisal_prior + relative_surprisal_likelihood
|
||||
|
||||
# Compute the posterior probability P(w|g) = sigmoid(relative_surprisal).
|
||||
return torch.sigmoid(relative_surprisal)
|
||||
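A toy numerical check of the posterior computation above (not part of the diff), for a single sequence with two unmasked positions, depth 1 and a 0.5 prior; the likelihood values are made up:

```python
import torch

lw = torch.tensor([[[0.7], [0.6]]])  # P(g | watermarked)
lu = torch.full_like(lw, 0.5)        # P(g | unwatermarked) for Bernoulli(0.5) g-values
mask = torch.ones(1, 2)

log_odds = (torch.log(lw) - torch.log(lu)) * mask.unsqueeze(-1)
prior = torch.tensor(0.5)
posterior = torch.sigmoid(log_odds.sum() + torch.log(prior) - torch.log(1 - prior))
print(posterior)  # ~0.63 -> slightly more likely watermarked than not
```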
|
||||
def forward(
|
||||
self,
|
||||
g_values: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
loss_batch_weight=1,
|
||||
return_dict=False,
|
||||
) -> BayesianWatermarkDetectorModelOutput:
|
||||
"""
|
||||
Computes the watermarked posterior P(watermarked|g_values).
|
||||
|
||||
Args:
|
||||
g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth, ...)`):
|
||||
g-values (with values 0 or 1)
|
||||
mask:
|
||||
A binary array shape [batch_size, seq_len] indicating which g-values should be used. g-values with mask
|
||||
value 0 are discarded.
|
||||
|
||||
Returns:
|
||||
p(watermarked | g_values), of shape [batch_size].
|
||||
"""
|
||||
|
||||
likelihoods_watermarked = self.likelihood_model_watermarked(g_values)
|
||||
likelihoods_unwatermarked = 0.5 * torch.ones_like(g_values)
|
||||
out = self._compute_posterior(
|
||||
likelihoods_watermarked=likelihoods_watermarked,
|
||||
likelihoods_unwatermarked=likelihoods_unwatermarked,
|
||||
mask=mask,
|
||||
prior=self.prior,
|
||||
)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = BCELoss()
|
||||
loss_unwweight = torch.sum(self.likelihood_model_watermarked.delta**2)
|
||||
loss_weight = loss_unwweight * loss_batch_weight
|
||||
loss = loss_fct(torch.clamp(out, 1e-5, 1 - 1e-5), labels) + loss_weight
|
||||
|
||||
if not return_dict:
|
||||
return (out,) if loss is None else (out, loss)
|
||||
|
||||
return BayesianWatermarkDetectorModelOutput(loss=loss, posterior_probabilities=out)
|
||||
|
||||
|
||||
class SynthIDTextWatermarkDetector:
|
||||
r"""
|
||||
SynthID text watermark detector class.
|
||||
|
||||
This class has to be initialized with the trained Bayesian detector module. Check the script
in examples/research_projects/synthid_text/detector_training.py for an example of training, saving and loading
this detector module. The folder also showcases an example use case of this detector.
|
||||
|
||||
Parameters:
|
||||
detector_module ([`BayesianDetectorModel`]):
|
||||
Bayesian detector module object initialized with parameters.
|
||||
Check examples/research_projects/synthid_text/detector_training.py for usage.
|
||||
logits_processor (`SynthIDTextWatermarkLogitsProcessor`):
|
||||
The logits processor used for watermarking.
|
||||
tokenizer (`Any`):
|
||||
The tokenizer used for the model.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
>>> from transformers import (
|
||||
... AutoTokenizer, BayesianDetectorModel, SynthIDTextWatermarkLogitsProcessor, SynthIDTextWatermarkDetector
|
||||
... )
|
||||
|
||||
>>> # Load the detector. See examples/research_projects/synthid_text for training a detector.
|
||||
>>> detector_model = BayesianDetectorModel.from_pretrained("joaogante/dummy_synthid_detector")
|
||||
>>> logits_processor = SynthIDTextWatermarkLogitsProcessor(
|
||||
... **detector_model.config.watermarking_config, device="cpu"
|
||||
... )
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(detector_model.config.model_name)
|
||||
>>> detector = SynthIDTextWatermarkDetector(detector_model, logits_processor, tokenizer)
|
||||
|
||||
>>> # Test whether a certain string is watermarked
|
||||
>>> test_input = tokenizer(["This is a test input"], return_tensors="pt")
|
||||
>>> is_watermarked = detector(test_input.input_ids)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
detector_module: BayesianDetectorModel,
|
||||
logits_processor: SynthIDTextWatermarkLogitsProcessor,
|
||||
tokenizer: Any,
|
||||
):
|
||||
self.detector_module = detector_module
|
||||
self.logits_processor = logits_processor
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def __call__(self, tokenized_outputs: torch.Tensor):
|
||||
# eos mask is computed, skip first ngram_len - 1 tokens
|
||||
# eos_mask will be of shape [batch_size, output_len]
|
||||
eos_token_mask = self.logits_processor.compute_eos_token_mask(
|
||||
input_ids=tokenized_outputs,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
)[:, self.logits_processor.ngram_len - 1 :]
|
||||
|
||||
# context repetition mask is computed
|
||||
context_repetition_mask = self.logits_processor.compute_context_repetition_mask(
|
||||
input_ids=tokenized_outputs,
|
||||
)
|
||||
# context repetition mask shape [batch_size, output_len - (ngram_len - 1)]
|
||||
|
||||
combined_mask = context_repetition_mask * eos_token_mask
|
||||
|
||||
g_values = self.logits_processor.compute_g_values(
|
||||
input_ids=tokenized_outputs,
|
||||
)
|
||||
# g values shape [batch_size, output_len - (ngram_len - 1), depth]
|
||||
return self.detector_module(g_values, combined_mask)
|
||||
|
@ -191,6 +191,20 @@ class AlternatingCodebooksLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class BayesianDetectorConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class BayesianDetectorModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class BeamScorer(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@ -457,6 +471,27 @@ class SuppressTokensLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SynthIDTextWatermarkDetector(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SynthIDTextWatermarkingConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SynthIDTextWatermarkLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class TemperatureLogitsWarper(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -16,6 +16,7 @@
|
||||
import unittest
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import is_torch_available
|
||||
@ -48,6 +49,7 @@ if is_torch_available():
|
||||
PrefixConstrainedLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
SequenceBiasLogitsProcessor,
|
||||
SynthIDTextWatermarkLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
@ -975,3 +977,187 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores_wo_bias = scores[:, -1].clone()
|
||||
out = watermark(input_ids=input_ids, scores=scores)
|
||||
self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())
|
||||
|
||||
@parameterized.expand([(5, 3, 10000), (10, 5, 1000)])
|
||||
def test_synthidtext_watermarking_processor_bias_uniformity(self, ngram_len, num_layers, vocab_size):
|
||||
"""Test SynthID watermarked distribution bias uniformity over iterations."""
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
watermarking_config = {
|
||||
"ngram_len": ngram_len,
|
||||
"keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
|
||||
"sampling_table_size": 2**16,
|
||||
"sampling_table_seed": 0,
|
||||
"context_history_size": 512,
|
||||
"device": torch_device,
|
||||
}
|
||||
batch_size = 100000
|
||||
ngrams = torch.randint(
|
||||
low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, ngram_len),
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
|
||||
g_values = logits_processor.compute_g_values(ngrams)
|
||||
g_values_mean = torch.mean(torch.mean(g_values.float(), dim=0))
|
||||
self.assertAlmostEqual(g_values_mean, 0.5, delta=0.01)
|
||||
|
||||
@parameterized.expand([(10000, 3), (1000, 20)])
|
||||
def test_synthidtext_watermark_processor_bias_uniformity_across_vocab(self, vocab_size, num_layers):
|
||||
"""Test SynthID watermarked distribution bias uniformity over vocabs of the model."""
|
||||
batch_size = 1000
|
||||
ngram_len = 5
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
watermarking_config = {
|
||||
"ngram_len": ngram_len,
|
||||
"keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
|
||||
"sampling_table_size": 2**16,
|
||||
"sampling_table_seed": 0,
|
||||
"context_history_size": 512,
|
||||
"device": torch_device,
|
||||
}
|
||||
n_minus_1_grams = torch.randint(
|
||||
low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, watermarking_config["ngram_len"] - 1),
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
|
||||
ngram_keys, _ = logits_processor._compute_keys(
|
||||
n_minus_1_grams,
|
||||
torch.stack([torch.arange(vocab_size, device=torch_device) for _ in range(batch_size)]),
|
||||
)
|
||||
|
||||
g_values = logits_processor.sample_g_values(ngram_keys)
|
||||
# g_values shape should be [batch_size, vocab_size, num_layers]
|
||||
g_values_mean = torch.mean(torch.mean(g_values.float(), dim=1))
|
||||
self.assertAlmostEqual(g_values_mean, 0.5, delta=0.001)
|
||||
|
||||
@parameterized.expand([(2, "uniform"), (10, "uniform"), (2, "random"), (10, "random")])
|
||||
def test_synthidtext_watermark_processor_distributional_convergence(self, vocab_size, logits_type):
|
||||
"""Check if watermarked distribution converges to unwatermarked logits distribution."""
|
||||
batch_size = 1500
|
||||
num_keys = 1000
|
||||
|
||||
updated_softmaxes = 0
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
if logits_type == "uniform":
|
||||
fixed_logits = torch.ones((batch_size, vocab_size), device=torch_device)
|
||||
elif logits_type == "random":
|
||||
fixed_logits = torch.rand(
|
||||
(
|
||||
1,
|
||||
vocab_size,
|
||||
),
|
||||
device=torch_device,
|
||||
)
|
||||
fixed_logits = fixed_logits.repeat(batch_size, 1)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized logits_type {logits_type}")
|
||||
for _ in range(num_keys):
|
||||
watermarking_config = {
|
||||
"ngram_len": 5,
|
||||
"keys": np.random.randint(0, 10**9, size=(1,), dtype=np.int64),
|
||||
"sampling_table_size": 2**16,
|
||||
"sampling_table_seed": 0,
|
||||
"context_history_size": 1024,
|
||||
"device": torch_device,
|
||||
}
|
||||
|
||||
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
|
||||
|
||||
ngrams = torch.randint(
|
||||
low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, watermarking_config["ngram_len"]),
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
# Insert ngram-1 into logit_processor state.
|
||||
for idx in range(watermarking_config["ngram_len"] - 1):
|
||||
_ = logits_processor(ngrams[:, :idx], fixed_logits)
|
||||
|
||||
updated_scores = logits_processor(ngrams, fixed_logits)
|
||||
updated_softmaxes += torch.nn.functional.softmax(updated_scores, dim=1).cpu().numpy()
|
||||
|
||||
updated_softmaxes = np.mean(updated_softmaxes, axis=0) / num_keys
|
||||
is_close = torch.all(
|
||||
torch.isclose(
|
||||
torch.tensor(updated_softmaxes, device=torch_device),
|
||||
torch.nn.Softmax()(fixed_logits[0]), # Take any batch entry, all are same.
|
||||
atol=1e-3,
|
||||
rtol=0,
|
||||
)
|
||||
)
|
||||
self.assertTrue(is_close)
|
||||
|
||||
@parameterized.expand([(2, 10, 1, 0.01), (100, 5, 1, 0.01), (100, 10, 2, 0.02)])
|
||||
def test_synthidtext_watermark_processor_bias_test(self, vocab_size, ngram_len, num_layers, atol):
|
||||
"""Test SynthID watermarking bias matches theoretical value."""
|
||||
batch_size = 20000
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
keys = [np.random.randint(0, 10**9) for _ in range(num_layers)]
|
||||
# Use 10**9 rather than vocab_size to ensure variety in (n-1)-grams.
|
||||
context = torch.randint(
|
||||
low=0,
|
||||
high=10**9,
|
||||
size=(batch_size, ngram_len - 1),
|
||||
dtype=torch.int64,
|
||||
generator=generator,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
context_history_size = 1024
|
||||
logits_processor = SynthIDTextWatermarkLogitsProcessor(
|
||||
ngram_len=ngram_len,
|
||||
keys=keys,
|
||||
sampling_table_size=2**16,
|
||||
sampling_table_seed=0,
|
||||
context_history_size=context_history_size,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
scores = torch.ones(
|
||||
(batch_size, vocab_size),
|
||||
dtype=torch.float64,
|
||||
device=torch_device,
|
||||
)
|
||||
# Init state of the logits processor.
|
||||
logits_processor(context, scores)
|
||||
# insert context into the state.
|
||||
for idx in range(1, ngram_len - 1):
|
||||
_ = logits_processor(context[:, :idx], scores)
|
||||
|
||||
updated_scores = logits_processor(context, scores)
|
||||
|
||||
probs = torch.nn.functional.softmax(updated_scores, dim=1)
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
next_tokens = torch.multinomial(
|
||||
probs,
|
||||
num_samples=1,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
ngrams = torch.concat((context, next_tokens), dim=1)
|
||||
g_values = logits_processor.compute_g_values(ngrams)
|
||||
mean_g_values = g_values.mean(dtype=torch.float64, dim=(0, 1))
|
||||
|
||||
expected_mean_g_value = logits_processor.expected_mean_g_value(
|
||||
vocab_size=vocab_size,
|
||||
)
|
||||
is_close = torch.all(
|
||||
torch.isclose(
|
||||
mean_g_values,
|
||||
torch.tensor(expected_mean_g_value, dtype=torch.float64, device=torch_device),
|
||||
atol=atol,
|
||||
rtol=0,
|
||||
)
|
||||
)
|
||||
self.assertTrue(is_close)
|
||||
|
@ -84,6 +84,7 @@ if is_torch_available():
|
||||
SampleEncoderDecoderOutput,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
SynthIDTextWatermarkingConfig,
|
||||
WatermarkDetector,
|
||||
WatermarkingConfig,
|
||||
)
|
||||
@ -2517,9 +2518,9 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@slow
|
||||
def test_watermark_generation(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
|
||||
def test_green_red_watermark_generation(self):
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
|
||||
input_len = model_inputs["input_ids"].shape[-1]
|
||||
@ -2548,6 +2549,61 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
|
||||
self.assertListEqual(detection_out.prediction.tolist(), [False])
|
||||
|
||||
"""Check the mean bias inserted by the watermarking algorithm."""
|
||||
|
||||
@slow
|
||||
def test_synthid_text_watermark_generation_mean_expected_bias(self):
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
|
||||
input_len = 5
|
||||
batch_size = 200
|
||||
|
||||
# generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :)
|
||||
watermark_config = SynthIDTextWatermarkingConfig(keys=[10, 20], ngram_len=5, debug_mode=True)
|
||||
logits_processor = watermark_config.construct_processor(model.config.vocab_size, torch_device)
|
||||
mean_g_values_repeats = []
|
||||
for _ in range(40):
|
||||
input_ids = torch.zeros(
|
||||
(batch_size, input_len),
|
||||
dtype=torch.int64,
|
||||
device=torch_device,
|
||||
)
|
||||
model_inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": torch.ones_like(input_ids, device=torch_device),
|
||||
}
|
||||
output = model.generate(
|
||||
**model_inputs, watermarking_config=watermark_config, do_sample=True, max_length=500, top_k=1000
|
||||
)
|
||||
g_values = logits_processor.compute_g_values(input_ids=output[:, input_len:])
|
||||
context_repetition_mask = logits_processor.compute_context_repetition_mask(
|
||||
input_ids=output[:, input_len:],
|
||||
).unsqueeze(dim=2)
|
||||
|
||||
mean_g_values = torch.masked.mean(
|
||||
g_values,
|
||||
mask=context_repetition_mask,
|
||||
dim=0,
|
||||
keepdim=True,
|
||||
dtype=torch.float64,
|
||||
)
|
||||
mean_g_values_repeats.append(mean_g_values)
|
||||
|
||||
mean_g_values = torch.concat(mean_g_values_repeats, dim=0).mean(dim=0)
|
||||
expected_mean_g_value = logits_processor.expected_mean_g_value(
|
||||
vocab_size=model.config.vocab_size,
|
||||
)
|
||||
atol = 0.03
|
||||
is_close = torch.isclose(
|
||||
mean_g_values,
|
||||
torch.tensor(expected_mean_g_value, dtype=torch.float64),
|
||||
atol=atol,
|
||||
rtol=0,
|
||||
)
|
||||
self.assertTrue(torch.all(is_close))
|
||||
|
||||
@slow
|
||||
def test_beam_search_example_integration(self):
|
||||
# PT-only test: TF doesn't have a BeamSearchScorer
|
||||
|