Add SynthID (watermarking by Google DeepMind) (#34350)

* Add SynthIDTextWatermarkLogitsProcessor

* Resolving comments.

* Resolving comments.

* Resolving commits.

* Improving SynthIDWatermark tests.

* switch to PT version

* detector as pretrained model + style

* update training + style

* rebase

* Update logits_process.py

* Improving SynthIDWatermark tests.

* Shift detector training to wikitext negatives and stabilize with lower learning rate.

* Clean up.

* in for 7B

* cleanup

* Support python 3.8.

* README and final cleanup.

* HF Hub upload and initialize.

* Update requirements for synthid_text.

* Adding SynthIDTextWatermarkDetector.

* Detector testing.

* Documentation changes.

* Copyrights fix.

* Fix detector api.

* ironing out errors

* ironing out errors

* training checks

* make fixup and make fix-copies

* docstrings and add to docs

* copyright

* BC

* test docstrings

* move import

* protect type hints

* top level imports

* watermarking example

* direct imports

* tpr fpr meaning

* process_kwargs

* SynthIDTextWatermarkingConfig docstring

* assert -> exception

* example updates

* no immutable dict (cant be serialized)

* pack fn

* einsum equivalent

* import order

* fix test on gpu

* add detector example

---------

Co-authored-by: Sumedh Ghaisas <sumedhg@google.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: sumedhghaisas2 <138781311+sumedhghaisas2@users.noreply.github.com>
Co-authored-by: raushan <raushan@huggingface.co>
Joao Gante 2024-10-23 21:18:52 +01:00 committed by GitHub
parent e50bf61dec
commit b0f0c61899
15 changed files with 2238 additions and 80 deletions


@ -185,6 +185,9 @@ generation.
[[autodoc]] SuppressTokensLogitsProcessor
- __call__
[[autodoc]] SynthIDTextWatermarkLogitsProcessor
- __call__
[[autodoc]] TemperatureLogitsWarper
- __call__
@ -418,5 +421,20 @@ A [`Constraint`] can be used to force the generation to include specific tokens
## Watermark Utils
[[autodoc]] WatermarkingConfig
- __call__
[[autodoc]] WatermarkDetector
- __call__
[[autodoc]] BayesianDetectorConfig
- __call__
[[autodoc]] BayesianDetectorModel
- __call__
[[autodoc]] SynthIDTextWatermarkingConfig
- __call__
[[autodoc]] SynthIDTextWatermarkDetector
- __call__


@ -41,8 +41,6 @@ like token streaming.
- validate
- get_generation_mode
[[autodoc]] generation.WatermarkingConfig
## GenerationMixin
[[autodoc]] GenerationMixin


@ -0,0 +1,34 @@
# SynthID Text
This project showcases the use of SynthID Text for watermarking LLM outputs. The code in this folder also
demonstrates training a detector to recognize such watermarked text. The trained detector can be uploaded to
a private HF Hub repo (private for security reasons) and re-initialized later through pretrained model loading, as also shown in this script.
See our blog post: https://huggingface.co/blog/synthid-text
## Python version
You need Python 3.9 to run this example.
## Installation and running
After installing `transformers`, install this project's requirements from the `requirements.txt` provided in this folder:
```
pip install -r requirements.txt
```
## To run the detector training
```
python detector_training.py --model_name=google/gemma-7b-it
```
Check the script for more parameters that are tunable, and see the paper at
https://www.nature.com/articles/s41586-024-08025-4 for more information on these parameters. An example
invocation with several of these flags overridden is shown below.
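For example, a run overriding several of the tunable flags defined in `detector_training.py` could look like the following (the flag values are only illustrative and the Hub repo name is a placeholder):
```
python detector_training.py \
    --model_name=google/gemma-7b-it \
    --temperature=0.7 \
    --top_k=40 \
    --top_p=0.9 \
    --num_negatives=10000 \
    --generation_length=512 \
    --save_model_to_hf_hub \
    --hf_hub_model_name=<your-username>/gemma-7b-it-detector
```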
## Caveat
Make sure to run the training of the detector and the detection on the same hardware type
(CPU, GPU, or TPU) to get consistent results (we use deterministic randomness, which is hardware dependent).
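## Example: generation and detection
The snippet below is a minimal sketch of how a trained detector can be used at inference time. It assumes the detector was uploaded by `detector_training.py` (the repo name `your-username/synthid-detector` is a placeholder) and that its config carries the base model name and watermarking settings via `set_detector_information`:
```
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BayesianDetectorModel,
    SynthIDTextWatermarkDetector,
    SynthIDTextWatermarkingConfig,
    SynthIDTextWatermarkLogitsProcessor,
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the detector trained by detector_training.py (placeholder repo name).
bayesian_detector = BayesianDetectorModel.from_pretrained("your-username/synthid-detector").to(device)
model_name = bayesian_detector.config.model_name
watermarking_config_dict = bayesian_detector.config.watermarking_config

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Generate text watermarked with the same keys the detector was trained on.
watermarking_config = SynthIDTextWatermarkingConfig(**watermarking_config_dict)
inputs = tokenizer(["Write an essay on cats."], return_tensors="pt", padding=True).to(device)
outputs = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=True, max_new_tokens=512)
outputs = outputs[:, inputs["input_ids"].shape[1] :]

# Score the continuation with the Bayesian detector.
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config_dict, device=device)
detector = SynthIDTextWatermarkDetector(bayesian_detector, logits_processor, tokenizer)
watermark_probability = float(detector(outputs)[0][0])
# Compare against thresholds chosen for your target TPR/FPR (see the training script).
print(f"P(watermarked) = {watermark_probability:.3f}")
```
As in the training script, the resulting probability should be compared against upper/lower thresholds chosen from the detector's TPR/FPR trade-off.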


@ -0,0 +1,502 @@
# coding=utf-8
# Copyright 2024 Google DeepMind.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import dataclasses
import enum
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BayesianDetectorConfig,
BayesianDetectorModel,
SynthIDTextWatermarkDetector,
SynthIDTextWatermarkingConfig,
SynthIDTextWatermarkLogitsProcessor,
)
from utils import (
get_tokenized_uwm_outputs,
get_tokenized_wm_outputs,
process_raw_model_outputs,
update_fn_if_fpr_tpr,
upload_model_to_hf,
)
@enum.unique
class ValidationMetric(enum.Enum):
"""Direction along the z-axis."""
TPR_AT_FPR = "tpr_at_fpr"
CROSS_ENTROPY = "cross_entropy"
@dataclasses.dataclass
class TrainingArguments:
"""Training arguments pertaining to the training loop itself."""
eval_metric: Optional[str] = dataclasses.field(
default=ValidationMetric.TPR_AT_FPR, metadata={"help": "The evaluation metric used."}
)
def train_detector(
detector: torch.nn.Module,
g_values: torch.Tensor,
mask: torch.Tensor,
watermarked: torch.Tensor,
epochs: int = 250,
learning_rate: float = 1e-3,
minibatch_size: int = 64,
seed: int = 0,
l2_weight: float = 0.0,
shuffle: bool = True,
g_values_val: Optional[torch.Tensor] = None,
mask_val: Optional[torch.Tensor] = None,
watermarked_val: Optional[torch.Tensor] = None,
verbose: bool = False,
validation_metric: ValidationMetric = ValidationMetric.TPR_AT_FPR,
) -> Tuple[Dict[str, Any], float]:
"""Trains a Bayesian detector model.
Args:
detector: Bayesian detector model to train.
g_values: g-values of shape [num_train, seq_len, watermarking_depth].
mask: A binary array shape [num_train, seq_len] indicating which g-values
should be used. g-values with mask value 0 are discarded.
watermarked: A binary array of shape [num_train] indicating whether the
example is watermarked (0: unwatermarked, 1: watermarked).
epochs: Number of epochs to train for.
learning_rate: Learning rate for optimizer.
minibatch_size: Minibatch size for training. Note that a minibatch
requires ~ 32 * minibatch_size * seq_len * watermarking_depth *
watermarking_depth bits of memory.
seed: Seed for parameter initialization.
l2_weight: Weight to apply to L2 regularization for delta parameters.
shuffle: Whether to shuffle before training.
g_values_val: Validation g-values of shape [num_val, seq_len,
watermarking_depth].
mask_val: Validation mask of shape [num_val, seq_len].
watermarked_val: Validation watermark labels of shape [num_val].
verbose: Boolean indicating verbosity of training. If true, the loss will
be printed. Defaults to False.
validation_metric: Metric to use for validation. If TPR_AT_FPR, the negative
TPR@FPR=1% is used as the validation loss; otherwise the cross-entropy loss is used.
Returns:
Tuple of
training_history: Training history keyed by epoch number, where each value
is a dictionary containing the training loss and validation loss, keyed
by 'loss' and 'val_loss', respectively.
min_val_loss: Minimum validation loss achieved during training.
"""
# Set the random seed for reproducibility
torch.manual_seed(seed)
# Shuffle the data if required
if shuffle:
indices = torch.randperm(len(g_values))
g_values = g_values[indices]
mask = mask[indices]
watermarked = watermarked[indices]
# Initialize optimizer
optimizer = torch.optim.Adam(detector.parameters(), lr=learning_rate)
history = {}
min_val_loss = float("inf")
for epoch in range(epochs):
losses = []
detector.train()
num_batches = len(g_values) // minibatch_size
for i in range(0, len(g_values), minibatch_size):
end = i + minibatch_size
if end > len(g_values):
break
loss_batch_weight = l2_weight / num_batches
optimizer.zero_grad()
loss = detector(
g_values=g_values[i:end],
mask=mask[i:end],
labels=watermarked[i:end],
loss_batch_weight=loss_batch_weight,
)[1]
loss.backward()
optimizer.step()
losses.append(loss.item())
train_loss = sum(losses) / len(losses)
val_losses = []
val_loss = None
if g_values_val is not None:
detector.eval()
if validation_metric == ValidationMetric.TPR_AT_FPR:
val_loss = update_fn_if_fpr_tpr(
detector,
g_values_val,
mask_val,
watermarked_val,
minibatch_size=minibatch_size,
)
else:
for i in range(0, len(g_values_val), minibatch_size):
end = i + minibatch_size
if end > len(g_values_val):
break
with torch.no_grad():
v_loss = detector(
g_values=g_values_val[i:end],
mask=mask_val[i:end],
labels=watermarked_val[i:end],
loss_batch_weight=0,
)[1]
val_losses.append(v_loss.item())
val_loss = sum(val_losses) / len(val_losses)
# Store training history
history[epoch + 1] = {"loss": train_loss, "val_loss": val_loss}
if verbose:
if val_loss is not None:
print(f"Epoch {epoch}: loss {loss} (train), {val_loss} (val)")
else:
print(f"Epoch {epoch}: loss {loss} (train)")
if val_loss is not None and val_loss < min_val_loss:
min_val_loss = val_loss
best_val_epoch = epoch
if verbose:
print(f"Best val Epoch: {best_val_epoch}, min_val_loss: {min_val_loss}")
return history, min_val_loss
def train_best_detector(
tokenized_wm_outputs: Union[List[np.ndarray], np.ndarray],
tokenized_uwm_outputs: Union[List[np.ndarray], np.ndarray],
logits_processor: SynthIDTextWatermarkLogitsProcessor,
tokenizer: Any,
torch_device: torch.device,
test_size: float = 0.3,
pos_truncation_length: Optional[int] = 200,
neg_truncation_length: Optional[int] = 100,
max_padded_length: int = 2300,
n_epochs: int = 50,
learning_rate: float = 2.1e-2,
l2_weights: np.ndarray = np.logspace(-3, -2, num=4),
verbose: bool = False,
validation_metric: ValidationMetric = ValidationMetric.TPR_AT_FPR,
):
"""Train and return the best detector given range of hyperparameters.
In practice, we have found that tuning pos_truncation_length,
neg_truncation_length, n_epochs, learning_rate and l2_weights can help
improve the performance of the detector. We recommend tuning these
parameters for your data.
"""
l2_weights = list(l2_weights)
(
train_g_values,
train_masks,
train_labels,
cv_g_values,
cv_masks,
cv_labels,
) = process_raw_model_outputs(
logits_processor,
tokenizer,
pos_truncation_length,
neg_truncation_length,
max_padded_length,
tokenized_wm_outputs,
test_size,
tokenized_uwm_outputs,
torch_device,
)
best_detector = None
lowest_loss = float("inf")
val_losses = []
for l2_weight in l2_weights:
config = BayesianDetectorConfig(watermarking_depth=len(logits_processor.keys))
detector = BayesianDetectorModel(config).to(torch_device)
_, min_val_loss = train_detector(
detector=detector,
g_values=train_g_values,
mask=train_masks,
watermarked=train_labels,
g_values_val=cv_g_values,
mask_val=cv_masks,
watermarked_val=cv_labels,
learning_rate=learning_rate,
l2_weight=l2_weight,
epochs=n_epochs,
verbose=verbose,
validation_metric=validation_metric,
)
val_losses.append(min_val_loss)
if min_val_loss < lowest_loss:
lowest_loss = min_val_loss
best_detector = detector
return best_detector, lowest_loss
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="google/gemma-2b-it",
help=("LM model to train the detector for."),
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help=("Temperature to sample from the model."),
)
parser.add_argument(
"--top_k",
type=int,
default=40,
help=("Top K for sampling."),
)
parser.add_argument(
"--top_p",
type=float,
default=1.0,
help=("Top P for sampling."),
)
parser.add_argument(
"--num_negatives",
type=int,
default=10000,
help=("Number of negatives for detector training."),
)
parser.add_argument(
"--pos_batch_size",
type=int,
default=32,
help=("Batch size of watermarked positives while sampling."),
)
parser.add_argument(
"--num_pos_batch",
type=int,
default=313,
help=("Number of positive batches for training."),
)
parser.add_argument(
"--generation_length",
type=int,
default=512,
help=("Generation length for sampling."),
)
parser.add_argument(
"--save_model_to_hf_hub",
action="store_true",
help=("Whether to save the trained model HF hub. By default it will be a private repo."),
)
parser.add_argument(
"--load_from_hf_hub",
action="store_true",
help=(
"Whether to load trained detector model from HF Hub, make sure its the model trained on the same model "
"we are loading in the script."
),
)
parser.add_argument(
"--hf_hub_model_name",
type=str,
default=None,
help=("HF hub model name for loading of saving the model."),
)
parser.add_argument(
"--eval_detector_on_prompts",
action="store_true",
help=("Evaluate detector on a prompt and print probability of watermark."),
)
args = parser.parse_args()
model_name = args.model_name
temperature = args.temperature
top_k = args.top_k
top_p = args.top_p
num_negatives = args.num_negatives
pos_batch_size = args.pos_batch_size
num_pos_batch = args.num_pos_batch
if num_pos_batch < 10:
raise ValueError("--num_pos_batch should be greater than 10.")
generation_length = args.generation_length
save_model_to_hf_hub = args.save_model_to_hf_hub
load_from_hf_hub = args.load_from_hf_hub
repo_name = args.hf_hub_model_name
eval_detector_on_prompts = args.eval_detector_on_prompts
NEG_BATCH_SIZE = 32
# Truncate outputs to this length for training.
POS_TRUNCATION_LENGTH = 200
NEG_TRUNCATION_LENGTH = 100
# Pad truncated outputs to this length for equal shape across all batches.
MAX_PADDED_LENGTH = 1000
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
if DEVICE.type not in ("cuda", "tpu"):
raise ValueError("We have found the training stable on GPU and TPU, we are working on" " a fix for CPUs")
model = None
if not load_from_hf_hub:
# Change this to make your watermark unique. Check documentation in the paper to understand the
# impact of these parameters.
DEFAULT_WATERMARKING_CONFIG = {
"ngram_len": 5, # This corresponds to H=4 context window size in the paper.
"keys": [
654,
400,
836,
123,
340,
443,
597,
160,
57,
29,
590,
639,
13,
715,
468,
990,
966,
226,
324,
585,
118,
504,
421,
521,
129,
669,
732,
225,
90,
960,
],
"sampling_table_size": 2**16,
"sampling_table_seed": 0,
"context_history_size": 1024,
}
watermark_config = SynthIDTextWatermarkingConfig(**DEFAULT_WATERMARKING_CONFIG)
model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
logits_processor = SynthIDTextWatermarkLogitsProcessor(**DEFAULT_WATERMARKING_CONFIG, device=DEVICE)
tokenized_wm_outputs = get_tokenized_wm_outputs(
model,
tokenizer,
watermark_config,
num_pos_batch,
pos_batch_size,
temperature,
generation_length,
top_k,
top_p,
DEVICE,
)
tokenized_uwm_outputs = get_tokenized_uwm_outputs(num_negatives, NEG_BATCH_SIZE, tokenizer, DEVICE)
best_detector, lowest_loss = train_best_detector(
tokenized_wm_outputs=tokenized_wm_outputs,
tokenized_uwm_outputs=tokenized_uwm_outputs,
logits_processor=logits_processor,
tokenizer=tokenizer,
torch_device=DEVICE,
test_size=0.3,
pos_truncation_length=POS_TRUNCATION_LENGTH,
neg_truncation_length=NEG_TRUNCATION_LENGTH,
max_padded_length=MAX_PADDED_LENGTH,
n_epochs=100,
learning_rate=3e-3,
l2_weights=[
0,
],
verbose=True,
validation_metric=ValidationMetric.TPR_AT_FPR,
)
else:
if repo_name is None:
raise ValueError("When loading from pretrained detector model name cannot be None.")
best_detector = BayesianDetectorModel.from_pretrained(repo_name).to(DEVICE)
best_detector.config.set_detector_information(
model_name=model_name, watermarking_config=DEFAULT_WATERMARKING_CONFIG
)
if save_model_to_hf_hub:
upload_model_to_hf(best_detector, repo_name)
# Evaluate model response with the detector
if eval_detector_on_prompts:
model_name = best_detector.config.model_name
watermark_config_dict = best_detector.config.watermarking_config
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermark_config_dict, device=DEVICE)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
synthid_text_detector = SynthIDTextWatermarkDetector(best_detector, logits_processor, tokenizer)
if model is None:
model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
watermarking_config = SynthIDTextWatermarkingConfig(**watermark_config_dict)
prompts = ["Write a essay on cats."]
inputs = tokenizer(
prompts,
return_tensors="pt",
padding=True,
).to(DEVICE)
_, inputs_len = inputs["input_ids"].shape
outputs = model.generate(
**inputs,
watermarking_config=watermarking_config,
do_sample=True,
max_length=inputs_len + generation_length,
temperature=temperature,
top_k=40,
top_p=1.0,
)
outputs = outputs[:, inputs_len:]
result = synthid_text_detector(outputs)
# You should set this based on expected fpr (false positive rate) and tpr (true positive rate).
# Check our demo at HF Spaces for more info.
upper_threshold = 0.95
lower_threshold = 0.12
if result[0][0] > upper_threshold:
print("The text is watermarked.")
elif lower_threshold < result[0][0] < upper_threshold:
print("It is hard to determine if the text is watermarked or not.")
else:
print("The text is not watermarked.")


@ -0,0 +1,5 @@
tensorflow-datasets>=4.9.3
torch >= 1.3
datasets
scikit-learn
tensorflow


@ -0,0 +1,408 @@
# coding=utf-8
# Copyright 2024 Google DeepMind.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
from typing import Any, List, Optional, Tuple
import datasets
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
from huggingface_hub import HfApi, create_repo
from huggingface_hub.utils import RepositoryNotFoundError
from sklearn import model_selection
import transformers
def pad_to_len(
arr: torch.Tensor,
target_len: int,
left_pad: bool,
eos_token: int,
device: torch.device,
) -> torch.Tensor:
"""Pad or truncate array to given length."""
if arr.shape[1] < target_len:
shape_for_ones = list(arr.shape)
shape_for_ones[1] = target_len - shape_for_ones[1]
padded = (
torch.ones(
shape_for_ones,
device=device,
dtype=torch.long,
)
* eos_token
)
if not left_pad:
arr = torch.concatenate((arr, padded), dim=1)
else:
arr = torch.concatenate((padded, arr), dim=1)
else:
arr = arr[:, :target_len]
return arr
def filter_and_truncate(
outputs: torch.Tensor,
truncation_length: Optional[int],
eos_token_mask: torch.Tensor,
) -> torch.Tensor:
"""Filter and truncate outputs to given length.
Args:
outputs: output tensor of shape [batch_size, output_len]
truncation_length: Length to truncate the final output.
eos_token_mask: EOS token mask of shape [batch_size, output_len]
Returns:
output tensor of shape [batch_size, truncation_length].
"""
if truncation_length:
outputs = outputs[:, :truncation_length]
truncation_mask = torch.sum(eos_token_mask, dim=1) >= truncation_length
return outputs[truncation_mask, :]
return outputs
def process_outputs_for_training(
all_outputs: List[torch.Tensor],
logits_processor: transformers.generation.SynthIDTextWatermarkLogitsProcessor,
tokenizer: Any,
pos_truncation_length: Optional[int],
neg_truncation_length: Optional[int],
max_length: int,
is_cv: bool,
is_pos: bool,
torch_device: torch.device,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Process raw model outputs into format understandable by the detector.
Args:
all_outputs: sequence of outputs of shape [batch_size, output_len].
logits_processor: logits processor used for watermarking.
tokenizer: tokenizer used for the model.
pos_truncation_length: Length to truncate wm outputs.
neg_truncation_length: Length to truncate uwm outputs.
max_length: Length to pad truncated outputs so that all processed entries
have the same shape.
is_cv: Process given outputs for cross validation.
is_pos: Process given outputs for positives.
torch_device: torch device to use.
Returns:
Tuple of
all_masks: list of masks of shape [batch_size, max_length].
all_g_values: list of g_values of shape [batch_size, max_length, depth].
"""
all_masks = []
all_g_values = []
for outputs in tqdm.tqdm(all_outputs):
# outputs is of shape [batch_size, output_len].
# output_len can differ from batch to batch.
eos_token_mask = logits_processor.compute_eos_token_mask(
input_ids=outputs,
eos_token_id=tokenizer.eos_token_id,
)
if is_pos or is_cv:
# filter with length for positives for both train and CV.
# We also filter for length when CV negatives are processed.
outputs = filter_and_truncate(outputs, pos_truncation_length, eos_token_mask)
elif not is_pos and not is_cv:
outputs = filter_and_truncate(outputs, neg_truncation_length, eos_token_mask)
# If no filtered outputs skip this batch.
if outputs.shape[0] == 0:
continue
# All outputs are padded to max-length with eos-tokens.
outputs = pad_to_len(outputs, max_length, False, tokenizer.eos_token_id, torch_device)
# outputs shape [num_filtered_entries, max_length]
eos_token_mask = logits_processor.compute_eos_token_mask(
input_ids=outputs,
eos_token_id=tokenizer.eos_token_id,
)
context_repetition_mask = logits_processor.compute_context_repetition_mask(
input_ids=outputs,
)
# context_repetition_mask of shape [num_filtered_entries, max_length -
# (ngram_len - 1)].
context_repetition_mask = pad_to_len(context_repetition_mask, max_length, True, 0, torch_device)
# We pad on left to get same max_length shape.
# context_repetition_mask of shape [num_filtered_entries, max_length].
combined_mask = context_repetition_mask * eos_token_mask
g_values = logits_processor.compute_g_values(
input_ids=outputs,
)
# g_values of shape [num_filtered_entries, max_length - (ngram_len - 1),
# depth].
g_values = pad_to_len(g_values, max_length, True, 0, torch_device)
# We pad on left to get same max_length shape.
# g_values of shape [num_filtered_entries, max_length, depth].
all_masks.append(combined_mask)
all_g_values.append(g_values)
return all_masks, all_g_values
def tpr_at_fpr(detector, detector_inputs, w_true, minibatch_size, target_fpr=0.01) -> torch.Tensor:
"""Calculates true positive rate (TPR) at false positive rate (FPR)=target_fpr."""
positive_idxs = w_true == 1
negative_idxs = w_true == 0
num_samples = detector_inputs[0].size(0)
w_preds = []
for start in range(0, num_samples, minibatch_size):
end = start + minibatch_size
detector_inputs_ = (
detector_inputs[0][start:end],
detector_inputs[1][start:end],
)
with torch.no_grad():
w_pred = detector(*detector_inputs_)[0]
w_preds.append(w_pred)
w_pred = torch.cat(w_preds, dim=0) # Concatenate predictions
positive_scores = w_pred[positive_idxs]
negative_scores = w_pred[negative_idxs]
# Calculate the FPR threshold
# Note: percentile -> quantile
fpr_threshold = torch.quantile(negative_scores, 1 - target_fpr)
# Note: need to switch to FP32 since torch.mean doesn't work with torch.bool
return torch.mean((positive_scores >= fpr_threshold).to(dtype=torch.float32)).item() # TPR
def update_fn_if_fpr_tpr(detector, g_values_val, mask_val, watermarked_val, minibatch_size):
"""Loss function for negative TPR@FPR=1% as the validation loss."""
tpr_ = tpr_at_fpr(
detector=detector,
detector_inputs=(g_values_val, mask_val),
w_true=watermarked_val,
minibatch_size=minibatch_size,
)
return -tpr_
def process_raw_model_outputs(
logits_processor,
tokenizer,
pos_truncation_length,
neg_truncation_length,
max_padded_length,
tokenized_wm_outputs,
test_size,
tokenized_uwm_outputs,
torch_device,
):
# Split data into train and CV
train_wm_outputs, cv_wm_outputs = model_selection.train_test_split(tokenized_wm_outputs, test_size=test_size)
train_uwm_outputs, cv_uwm_outputs = model_selection.train_test_split(tokenized_uwm_outputs, test_size=test_size)
process_kwargs = {
"logits_processor": logits_processor,
"tokenizer": tokenizer,
"pos_truncation_length": pos_truncation_length,
"neg_truncation_length": neg_truncation_length,
"max_length": max_padded_length,
"torch_device": torch_device,
}
# Process both train and CV data for training
wm_masks_train, wm_g_values_train = process_outputs_for_training(
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in train_wm_outputs],
is_pos=True,
is_cv=False,
**process_kwargs,
)
wm_masks_cv, wm_g_values_cv = process_outputs_for_training(
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in cv_wm_outputs],
is_pos=True,
is_cv=True,
**process_kwargs,
)
uwm_masks_train, uwm_g_values_train = process_outputs_for_training(
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in train_uwm_outputs],
is_pos=False,
is_cv=False,
**process_kwargs,
)
uwm_masks_cv, uwm_g_values_cv = process_outputs_for_training(
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in cv_uwm_outputs],
is_pos=False,
is_cv=True,
**process_kwargs,
)
# We get list of data; here we concat all together to be passed to the detector.
def pack(mask, g_values):
mask = torch.cat(mask, dim=0)
g = torch.cat(g_values, dim=0)
return mask, g
wm_masks_train, wm_g_values_train = pack(wm_masks_train, wm_g_values_train)
# Note: Use float instead of bool. Otherwise, the entropy calculation doesn't work
wm_labels_train = torch.ones((wm_masks_train.shape[0],), dtype=torch.float, device=torch_device)
wm_masks_cv, wm_g_values_cv = pack(wm_masks_cv, wm_g_values_cv)
wm_labels_cv = torch.ones((wm_masks_cv.shape[0],), dtype=torch.float, device=torch_device)
uwm_masks_train, uwm_g_values_train = pack(uwm_masks_train, uwm_g_values_train)
uwm_labels_train = torch.zeros((uwm_masks_train.shape[0],), dtype=torch.float, device=torch_device)
uwm_masks_cv, uwm_g_values_cv = pack(uwm_masks_cv, uwm_g_values_cv)
uwm_labels_cv = torch.zeros((uwm_masks_cv.shape[0],), dtype=torch.float, device=torch_device)
# Concat pos and negatives data together.
train_g_values = torch.cat((wm_g_values_train, uwm_g_values_train), dim=0).squeeze()
train_labels = torch.cat((wm_labels_train, uwm_labels_train), dim=0).squeeze()
train_masks = torch.cat((wm_masks_train, uwm_masks_train), dim=0).squeeze()
cv_g_values = torch.cat((wm_g_values_cv, uwm_g_values_cv), dim=0).squeeze()
cv_labels = torch.cat((wm_labels_cv, uwm_labels_cv), dim=0).squeeze()
cv_masks = torch.cat((wm_masks_cv, uwm_masks_cv), dim=0).squeeze()
# Shuffle data.
shuffled_idx = torch.randperm(train_g_values.shape[0]) # Use torch for GPU compatibility
train_g_values = train_g_values[shuffled_idx]
train_labels = train_labels[shuffled_idx]
train_masks = train_masks[shuffled_idx]
# Shuffle the cross-validation data
shuffled_idx_cv = torch.randperm(cv_g_values.shape[0]) # Use torch for GPU compatibility
cv_g_values = cv_g_values[shuffled_idx_cv]
cv_labels = cv_labels[shuffled_idx_cv]
cv_masks = cv_masks[shuffled_idx_cv]
# Del some variables so we free up GPU memory.
del (
wm_g_values_train,
wm_labels_train,
wm_masks_train,
wm_g_values_cv,
wm_labels_cv,
wm_masks_cv,
)
gc.collect()
torch.cuda.empty_cache()
return train_g_values, train_masks, train_labels, cv_g_values, cv_masks, cv_labels
def get_tokenized_uwm_outputs(num_negatives, neg_batch_size, tokenizer, device):
dataset, info = tfds.load("wikipedia/20230601.en", split="train", with_info=True)
dataset = dataset.take(num_negatives)
# Convert the dataset to a DataFrame
df = tfds.as_dataframe(dataset, info)
ds = tf.data.Dataset.from_tensor_slices(dict(df))
tf.random.set_seed(0)
ds = ds.shuffle(buffer_size=10_000)
ds = ds.batch(batch_size=neg_batch_size)
tokenized_uwm_outputs = []
# Pad to this length (on the right) for batching.
padded_length = 1000
for i, batch in tqdm.tqdm(enumerate(ds)):
responses = [val.decode() for val in batch["text"].numpy()]
inputs = tokenizer(
responses,
return_tensors="pt",
padding=True,
).to(device)
inputs = inputs["input_ids"].cpu().numpy()
if inputs.shape[1] >= padded_length:
inputs = inputs[:, :padded_length]
else:
inputs = np.concatenate(
[inputs, np.ones((neg_batch_size, padded_length - inputs.shape[1])) * tokenizer.eos_token_id], axis=1
)
tokenized_uwm_outputs.append(inputs)
if len(tokenized_uwm_outputs) * neg_batch_size > num_negatives:
break
return tokenized_uwm_outputs
def get_tokenized_wm_outputs(
model,
tokenizer,
watermark_config,
num_pos_batches,
pos_batch_size,
temperature,
max_output_len,
top_k,
top_p,
device,
):
eli5_prompts = datasets.load_dataset("Pavithree/eli5")
wm_outputs = []
for batch_id in tqdm.tqdm(range(num_pos_batches)):
prompts = eli5_prompts["train"]["title"][batch_id * pos_batch_size : (batch_id + 1) * pos_batch_size]
prompts = [prompt.strip('"') for prompt in prompts]
inputs = tokenizer(
prompts,
return_tensors="pt",
padding=True,
).to(device)
_, inputs_len = inputs["input_ids"].shape
outputs = model.generate(
**inputs,
watermarking_config=watermark_config,
do_sample=True,
max_length=inputs_len + max_output_len,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
wm_outputs.append(outputs[:, inputs_len:].cpu().detach())
del outputs, inputs, prompts
gc.collect()
gc.collect()
torch.cuda.empty_cache()
return wm_outputs
def upload_model_to_hf(model, hf_repo_name: str, private: bool = True):
api = HfApi()
# Check if the repository exists
try:
api.repo_info(repo_id=hf_repo_name, use_auth_token=True)
print(f"Repository '{hf_repo_name}' already exists.")
except RepositoryNotFoundError:
# If the repository does not exist, create it
print(f"Repository '{hf_repo_name}' not found. Creating it...")
create_repo(repo_id=hf_repo_name, private=private, use_auth_token=True)
print(f"Repository '{hf_repo_name}' created successfully.")
# Push the model to the Hugging Face Hub
print(f"Uploading model to Hugging Face repo '{hf_repo_name}'...")
model.push_to_hub(repo_id=hf_repo_name, use_auth_token=True)


@ -1301,6 +1301,8 @@ else:
_import_structure["generation"].extend(
[
"AlternatingCodebooksLogitsProcessor",
"BayesianDetectorConfig",
"BayesianDetectorModel",
"BeamScorer",
"BeamSearchScorer",
"ClassifierFreeGuidanceLogitsProcessor",
@ -1339,6 +1341,9 @@ else:
"StopStringCriteria",
"SuppressTokensAtBeginLogitsProcessor",
"SuppressTokensLogitsProcessor",
"SynthIDTextWatermarkDetector",
"SynthIDTextWatermarkingConfig",
"SynthIDTextWatermarkLogitsProcessor",
"TemperatureLogitsWarper",
"TopKLogitsWarper",
"TopPLogitsWarper",
@ -6213,6 +6218,8 @@ if TYPE_CHECKING:
)
from .generation import (
AlternatingCodebooksLogitsProcessor,
BayesianDetectorConfig,
BayesianDetectorModel,
BeamScorer,
BeamSearchScorer,
ClassifierFreeGuidanceLogitsProcessor,
@ -6251,6 +6258,9 @@ if TYPE_CHECKING:
StopStringCriteria,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
SynthIDTextWatermarkDetector,
SynthIDTextWatermarkingConfig,
SynthIDTextWatermarkLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,


@ -18,7 +18,13 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab
_import_structure = {
"configuration_utils": ["GenerationConfig", "GenerationMode", "WatermarkingConfig"],
"configuration_utils": [
"BaseWatermarkingConfig",
"GenerationConfig",
"GenerationMode",
"SynthIDTextWatermarkingConfig",
"WatermarkingConfig",
],
"streamers": ["TextIteratorStreamer", "TextStreamer"],
}
@ -71,6 +77,7 @@ else:
"SequenceBiasLogitsProcessor",
"SuppressTokensLogitsProcessor",
"SuppressTokensAtBeginLogitsProcessor",
"SynthIDTextWatermarkLogitsProcessor",
"TemperatureLogitsWarper",
"TopKLogitsWarper",
"TopPLogitsWarper",
@ -110,6 +117,9 @@ else:
_import_structure["watermarking"] = [
"WatermarkDetector",
"WatermarkDetectorOutput",
"BayesianDetectorModel",
"BayesianDetectorConfig",
"SynthIDTextWatermarkDetector",
]
try:
@ -179,7 +189,13 @@ else:
]
if TYPE_CHECKING:
from .configuration_utils import GenerationConfig, GenerationMode, WatermarkingConfig
from .configuration_utils import (
BaseWatermarkingConfig,
GenerationConfig,
GenerationMode,
SynthIDTextWatermarkingConfig,
WatermarkingConfig,
)
from .streamers import TextIteratorStreamer, TextStreamer
try:
@ -217,6 +233,7 @@ if TYPE_CHECKING:
SequenceBiasLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
SynthIDTextWatermarkLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
@ -254,6 +271,9 @@ if TYPE_CHECKING:
SampleEncoderDecoderOutput,
)
from .watermarking import (
BayesianDetectorConfig,
BayesianDetectorModel,
SynthIDTextWatermarkDetector,
WatermarkDetector,
WatermarkDetectorOutput,
)


@ -18,8 +18,9 @@ import copy
import json
import os
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass, is_dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from .. import __version__
from ..configuration_utils import PretrainedConfig
@ -59,6 +60,7 @@ if is_torch_available():
StaticCache,
StaticCacheConfig,
)
from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig
@ -280,23 +282,10 @@ class GenerationConfig(PushToHubMixin):
low_memory (`bool`, *optional*):
Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
Used with beam search and contrastive search.
watermarking_config (`WatermarkingConfig` or `dict`, *optional*):
Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green" tokens.
If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
See [this paper](https://arxiv.org/abs/2306.04634) for more details. Accepts the following keys:
- greenlist_ratio (`float`):
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
- bias (`float`):
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
- hashing_key (`int`):
Hahsing key used for watermarking. Defaults to 15485863 (the millionth prime).
- seeding_scheme (`str`):
Algorithm to use for watermarking. Accepts values:
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
- context_width (`int`):
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
watermarking_config (`BaseWatermarkingConfig` or `dict`, *optional*):
Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green"
tokens. See the docs of [`SynthIDTextWatermarkingConfig`] and [`WatermarkingConfig`] for more
details. If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
> Parameters that define the output variables of generate
@ -430,7 +419,7 @@ class GenerationConfig(PushToHubMixin):
watermarking_config = kwargs.pop("watermarking_config", None)
if watermarking_config is None:
self.watermarking_config = None
elif isinstance(watermarking_config, WatermarkingConfig):
elif isinstance(watermarking_config, BaseWatermarkingConfig):
self.watermarking_config = watermarking_config
else:
self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config)
@ -766,7 +755,15 @@ class GenerationConfig(PushToHubMixin):
# 6. check watermarking arguments
if self.watermarking_config is not None:
if not isinstance(self.watermarking_config, WatermarkingConfig):
if not (
isinstance(self.watermarking_config, WatermarkingConfig)
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
):
warnings.warn(
"`watermarking_config` as a dict is deprecated. Please construct `watermarking_config` object with "
"`WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class.",
FutureWarning,
)
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()
@ -1287,52 +1284,20 @@ class GenerationConfig(PushToHubMixin):
@dataclass
class WatermarkingConfig:
"""
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.
Accepts the following keys:
- greenlist_ratio (`float`):
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
- bias (`float`):
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
- hashing_key (`int`):
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
- seeding_scheme (`str`):
Algorithm to use for watermarking. Accepts values:
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
- context_width(`int`):
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
"""
def __init__(
self,
greenlist_ratio: Optional[float] = 0.25,
bias: Optional[float] = 2.0,
hashing_key: Optional[int] = 15485863,
seeding_scheme: Optional[str] = "lefthash",
context_width: Optional[int] = 1,
):
self.greenlist_ratio = greenlist_ratio
self.bias = bias
self.hashing_key = hashing_key
self.seeding_scheme = seeding_scheme
self.context_width = context_width
class BaseWatermarkingConfig(ABC):
"""Generic watermarking config"""
@classmethod
def from_dict(cls, config_dict, **kwargs):
"""
Constructs a WatermarkingConfig instance from a dictionary of parameters.
Constructs a BaseWatermarkingConfig instance from a dictionary of parameters.
Args:
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
**kwargs: Additional keyword arguments to override dictionary values.
Returns:
WatermarkingConfig: Instance of WatermarkingConfig constructed from the dictionary.
BaseWatermarkingConfig: Instance of BaseWatermarkingConfig constructed from the dictionary.
"""
config = cls(**config_dict)
to_remove = []
@ -1394,6 +1359,49 @@ class WatermarkingConfig:
if hasattr(self, key):
setattr(self, key, value)
@abstractmethod
def validate(self): ...
@abstractmethod
def construct_processor(self, vocab_size): ...
@dataclass
class WatermarkingConfig(BaseWatermarkingConfig):
"""
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.
Accepts the following keys:
- greenlist_ratio (`float`):
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
- bias (`float`):
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
- hashing_key (`int`):
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
- seeding_scheme (`str`):
Algorithm to use for watermarking. Accepts values:
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
- context_width(`int`):
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
"""
def __init__(
self,
greenlist_ratio: Optional[float] = 0.25,
bias: Optional[float] = 2.0,
hashing_key: Optional[int] = 15485863,
seeding_scheme: Optional[str] = "lefthash",
context_width: Optional[int] = 1,
):
self.greenlist_ratio = greenlist_ratio
self.bias = bias
self.hashing_key = hashing_key
self.seeding_scheme = seeding_scheme
self.context_width = context_width
def validate(self):
watermark_missing_arg_msg = (
"Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
@ -1423,3 +1431,104 @@ class WatermarkingConfig:
found_value=self.context_width,
),
)
def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
return WatermarkLogitsProcessor(
vocab_size=vocab_size,
device=device,
greenlist_ratio=self.greenlist_ratio,
bias=self.bias,
hashing_key=self.hashing_key,
seeding_scheme=self.seeding_scheme,
context_width=self.context_width,
)
@dataclass
class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig):
"""
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
See [this paper](https://www.nature.com/articles/s41586-024-08025-4) for more details on the arguments.
Args:
ngram_len (`int`):
Ngram length.
keys (`List[int]`):
A sequence of watermarking keys, one for each depth.
context_history_size (`int`, *optional*, defaults to 1024):
Size of the tensor to keep track of seen contexts.
sampling_table_seed (`int`, *optional*, defaults to 0):
Random seed to generate the sampling table.
sampling_table_size (`int`, *optional*, defaults to 65536):
Size of the sampling table.
skip_first_ngram_calls (`bool`, *optional*, defaults to `False`):
Whether to skip first ngram calls.
debug_mode (`bool`, *optional*, defaults to `False`):
Logits are modified to a uniform distribution before the watermarking modification is applied. This is to
test the implementation.
Examples:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig
>>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
>>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b-it')
>>> # SynthID Text configuration
>>> watermarking_config = SynthIDTextWatermarkingConfig(
... keys=[654, 400, 836, 123, 340, 443, 597, 160, 57],
... ngram_len=5,
... )
>>> # Generation with watermarking
>>> tokenized_prompts = tokenizer(["your prompts here"])
>>> output_sequences = model.generate(
... **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True,
... )
>>> watermarked_text = tokenizer.batch_decode(output_sequences)
```
"""
def __init__(
self,
ngram_len: int,
keys: List[int],
context_history_size: int = 1024,
sampling_table_seed: int = 0,
sampling_table_size: int = 2**16,
skip_first_ngram_calls: bool = False,
debug_mode: bool = False,
):
self.ngram_len = ngram_len
self.keys = keys
self.sampling_table_size = sampling_table_size
self.sampling_table_seed = sampling_table_seed
self.context_history_size = context_history_size
self.skip_first_ngram_calls = skip_first_ngram_calls
self.debug_mode = debug_mode
def validate(self):
watermark_missing_arg_msg = (
"Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
"but found {found_value}"
)
if self.sampling_table_size > 2**24:
raise ValueError(
watermark_missing_arg_msg.format(
key="sampling_table_size",
correct_value="< 2**24",
found_value=self.sampling_table_size,
),
)
def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
return SynthIDTextWatermarkLogitsProcessor(
ngram_len=self.ngram_len,
keys=self.keys,
sampling_table_size=self.sampling_table_size,
sampling_table_seed=self.sampling_table_seed,
context_history_size=self.context_history_size,
device=device,
skip_first_ngram_calls=self.skip_first_ngram_calls,
debug_mode=self.debug_mode,
)


@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team
# Copyright 2024 The HuggingFace Inc. team and Google DeepMind.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -2460,6 +2460,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
final_greenlist.append(greedy_predictions[i])
return torch.tensor(final_greenlist, device=input_seq.device)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids.shape[-1] < self.context_width:
logger.warning(
@ -2477,3 +2478,478 @@ class WatermarkLogitsProcessor(LogitsProcessor):
scores_processed[b_idx, greenlist_ids] = scores_processed[b_idx, greenlist_ids] + self.bias
return scores_processed
class SynthIDTextWatermarkState:
"""SynthID watermarking state."""
def __init__(
self,
batch_size: int,
ngram_len: int,
context_history_size: int,
device: torch.device,
):
"""Initializes the state.
Args:
batch_size (`int`): Batch size.
ngram_len (`int`): Ngram length.
context_history_size (`int`): Size of the tensor to keep track of seen contexts.
device (`torch.device`): Device to use.
"""
self.context = torch.zeros(
(batch_size, ngram_len - 1),
dtype=torch.int64,
device=device,
)
self.context_history = torch.zeros(
(batch_size, context_history_size),
dtype=torch.int64,
device=device,
)
self.num_calls = 0
class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor):
r"""
Logits processor that implements watermarking techniques for text generation models.
This class facilitates the application of SynthID text watermarking, a method for embedding imperceptible signals
into generated text to aid in detecting synthetic content. It operates by subtly manipulating the probabilities of
token selection during text generation in a manner that can be reliably recovered later for verification.
Key Features:
* **State Management:** Maintains internal state to track token sequences and generate watermarking keys
dynamically.
* **Key Generation:** Computes hashes based on token sequences and watermarking parameters to create unique keys
for each position.
* **G-Value Sampling:** Employs a pre-computed sampling table to sample watermarking values (g-values) based on
the generated keys.
* **Score Adjustment:** Applies calculated g-values to modify token probabilities during generation, embedding the
watermark.
* **Context Repetition Handling:** Incorporates logic to avoid watermarking tokens in repeated contexts,
preserving naturalness.
* **EOS Token Masking:** Supports masking end-of-sentence tokens to prevent their inclusion in watermarking
calculations.
* **Utility Functions:** Provides functions to compute g-values directly, check for context repetition, create
EOS token masks, and estimate expected mean g-values.
Refer to the paper at https://www.nature.com/articles/s41586-024-08025-4 for more details.
Args:
ngram_len (`int`):
Ngram length.
keys (`List[int]`):
A sequence of watermarking keys, one for each depth.
sampling_table_size (`int`):
Size of the sampling table.
sampling_table_seed (`int`):
Random seed to generate the sampling table.
context_history_size (`int`):
Size of the tensor to keep track of seen contexts.
device (`torch.device`):
Device to use.
skip_first_ngram_calls (`bool`, *optional*, defaults to `False`):
Whether to skip first ngram calls.
debug_mode (`bool`, *optional*, defaults to `False`):
Logits are modified to a uniform distribution before the watermarking modification is applied. This is to
test the implementation.
Examples:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig
>>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
>>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b-it')
>>> # SynthID Text configuration
>>> watermarking_config = SynthIDTextWatermarkingConfig(
... keys=[654, 400, 836, 123, 340, 443, 597, 160, 57],
... ngram_len=5,
... )
>>> # Generation with watermarking
>>> tokenized_prompts = tokenizer(["your prompts here"])
>>> output_sequences = model.generate(
... **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True,
... )
>>> watermarked_text = tokenizer.batch_decode(output_sequences)
```
"""
def __init__(
self,
ngram_len: int,
keys: List[int],
sampling_table_size: int,
sampling_table_seed: int,
context_history_size: int,
device: torch.device,
skip_first_ngram_calls: bool = False,
debug_mode: bool = False,
):
self.ngram_len = ngram_len
self.keys = torch.tensor(keys, device=device)
generator = torch.Generator(device=device).manual_seed(sampling_table_seed)
# A random sampling table is pre-computed and modulo table size is applied to map from a hash of ngram keys to
# g values, this is similar to the hashtable implementation used in
# https://github.com/facebookresearch/three_bricks. We note that the hashing employed in this repository is
# different from that used to watermark the Gemini App, and hence the detectors trained based on the
# hashing in this repository will not transfer to text generated by the Gemini App.
self.sampling_table = torch.randint(
low=0,
high=2,
size=(sampling_table_size,),
generator=generator,
device=device,
)
self.context_history_size = context_history_size
self.device = device
self.state = None
self.skip_first_ngram_calls = skip_first_ngram_calls
self.debug_mode = debug_mode
def _init_state(self, batch_size: int):
"""Initializes the state."""
self.state = SynthIDTextWatermarkState(
batch_size=batch_size,
ngram_len=self.ngram_len,
context_history_size=self.context_history_size,
device=self.device,
)
def update_scores(self, scores: torch.FloatTensor, g_values: torch.FloatTensor) -> torch.FloatTensor:
"""Updates scores using the g values.
We assume that the scores are in the log space.
Args:
scores (`torch.FloatTensor`): Scores (batch_size, vocab_size).
g_values (`torch.FloatTensor`): G values (batch_size, vocab_size, depth).
Returns:
Updated scores (batch_size, vocab_size).
"""
_, _, depth = g_values.shape
probs = torch.softmax(scores, dim=1)
for i in range(depth):
g_values_at_depth = g_values[:, :, i]
g_mass_at_depth = (g_values_at_depth * probs).sum(axis=1, keepdims=True)
probs = probs * (1 + g_values_at_depth - g_mass_at_depth)
log_probs = torch.log(probs)
log_probs = torch.where(torch.isfinite(log_probs), log_probs, torch.finfo(log_probs.dtype).min)
return log_probs
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
self._check_input_ids_shape(input_ids)
batch_size, vocab_size = scores.shape
if self.debug_mode:
scores = torch.ones_like(scores)
# Currently indices is just an arange to compute watermarking on the dense logits.
all_indices = torch.stack([torch.arange(vocab_size, device=self.device) for _ in range(batch_size)])
if self.state is None:
# Initialize watermarking state if it does not exist.
self._init_state(batch_size)
else:
# Append last input id (which is the input id added in last call) to the
# previous context so we have the context to be used for current
# watermarking.
self.state.context = torch.concat(
(self.state.context, input_ids[:, -1:]),
dim=1,
)
self.state.context = self.state.context[:, 1:]
if self.state is None:
raise ValueError("self.state can't be None! Call `self._init_state` to initialize the state.")
self.state.num_calls += 1
# Don't watermark the first ngram_len - 1 tokens if set.
if self.skip_first_ngram_calls and self.state.num_calls < self.ngram_len:
return scores
# 2. Generate random keys for each ngram key combination.
ngram_keys, hash_result_with_just_context = self._compute_keys(self.state.context, all_indices)
# ngram_keys shape [batch_size, top_k, depth]
# 3. Sample g values.
g_values = self.sample_g_values(ngram_keys)
# g_values shape [batch_size, top_k, depth]
# 4. Modify scores.
updated_scores = self.update_scores(scores, g_values)
# updated scores shape [batch_size, top_k]
# 5. Check if the current watermarking context was previously used, if yes skip watermarking.
hash_result_with_just_context = hash_result_with_just_context[:, None]
is_repeated_context = (self.state.context_history == hash_result_with_just_context).any(
dim=1,
keepdim=True,
)
self.state.context_history = torch.concat(
(hash_result_with_just_context, self.state.context_history),
dim=1,
)[:, :-1]
updated_watermarked_scores = torch.where(
is_repeated_context,
input=scores,
other=updated_scores,
)
return updated_watermarked_scores
def accumulate_hash(
self,
current_hash: torch.LongTensor,
data: torch.LongTensor,
multiplier: int = 6364136223846793005,
increment: int = 1,
) -> torch.LongTensor:
"""
Accumulate hash of data on current hash.
Method uses adapted linear congruential generator with newlib/musl parameters.
This function has the following property:
f(x, data[T]) = f(f(x, data[:T - 1]), data[T])
This function expects current_hash.shape and data.shape[:-1] to
match/broadcastable.
Args:
current_hash (`torch.LongTensor`):
(shape,)
data (`torch.LongTensor`):
(shape, tensor_len)
multiplier (`int`, *optional*, defaults to 6364136223846793005):
multiplier of the linear congruential generator
increment (`int`, *optional*, defaults to 1):
increment of the linear congruential generator
Returns:
updated hash (shape,)
"""
for i in range(data.shape[-1]):
current_hash = torch.add(current_hash, data[..., i])
current_hash = torch.mul(current_hash, multiplier)
current_hash = torch.add(current_hash, increment)
return current_hash
def compute_ngram_keys(self, ngrams: torch.LongTensor) -> torch.LongTensor:
"""Computes random keys for each ngram and depth.
Args:
ngrams (`torch.LongTensor`):
Ngrams (batch_size, num_ngrams, ngram_len).
Returns:
ngram keys (batch_size, num_ngrams, depth).
"""
if len(ngrams.shape) != 3:
raise ValueError(
"Ngrams should be of shape (batch_size, num_ngrams, ngram_len), but" f" is {ngrams.shape}"
)
if ngrams.shape[2] != self.ngram_len:
raise ValueError(
"Ngrams should be of shape (batch_size, num_ngrams, ngram_len),"
f" where ngram_len is {self.ngram_len}, but is {ngrams.shape}"
)
batch_size, _, _ = ngrams.shape
hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
# hash_result shape [batch_size,]
# ngrams shape [batch_size, num_ngrams, ngram_len]
hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 1), out_dims=1)(hash_result, ngrams)
# hash_result shape [batch_size, num_ngrams]
keys = self.keys[None, None, :, None]
# hash_result shape [batch_size, num_ngrams]
# keys shape [1, 1, depth, 1]
hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 2), out_dims=2)(hash_result, keys)
# hash_result shape [batch_size, num_ngrams, depth]
return hash_result
def _compute_keys(
self, n_minus_1_grams: torch.LongTensor, indices: torch.LongTensor
) -> Tuple[torch.LongTensor, torch.LongTensor]:
"""Computes random keys for each ngram and depth.
Args:
n_minus_1_grams (`torch.LongTensor`):
Ngrams (batch_size, ngram_len - 1).
indices (`torch.LongTensor`):
indices of the continuations (batch_size, num_indices)
Returns:
Ngram keys (batch_size, num_indices, depth).
"""
batch_size, _ = n_minus_1_grams.shape
hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
# First hash n_minus_1 gram, for each batch entry we have a single
# n_minus_1 gram context.
# hash_result shape [batch_size]
# n_minus_1_gram shape [batch_size, ngram_len - 1]
hash_result_with_just_context = self.accumulate_hash(hash_result, n_minus_1_grams)
# hash_result shape [batch_size,]
# Indices is of shape [batch_size, num_indices], so we make it
# [batch_size, num_indices, 1] so we can vmap over num_indices dim.
hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 1), out_dims=1)(
hash_result_with_just_context, indices[:, :, None]
)
# hash_result shape [batch_size, num_indices]
# Basically we have a hash for each batch entry and each indices
# Now we add watermarking keys to this hash.
# keys are of shape [depth,]
# We add batch, num_indices and data dimension to this making it
# [1, 1, depth, 1].
# So we can vmap over the depth dimension for compute_hash
keys = self.keys[None, None, :, None]
hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 2), out_dims=2)(hash_result, keys)
# hash_result shape should be [batch_size, num_indices, depth]
return hash_result, hash_result_with_just_context
def sample_g_values(self, ngram_keys: torch.LongTensor) -> torch.LongTensor:
"""
Samples g values from Bernoulli distribution.
It is not possible to pass random keys in a vectorized way in torch. Instead
we pre-compute a random sampling table, and apply modulo table size to
map from ngram keys (int64) to g values.
Args:
ngram_keys (`torch.LongTensor`):
Random keys (batch_size, num_ngrams, depth).
Returns:
G values (batch_size, num_ngrams, depth).
"""
(sampling_table_size,) = self.sampling_table.shape
sampling_table = self.sampling_table.reshape((1, 1, sampling_table_size))
ngram_keys = ngram_keys % sampling_table_size
return torch.take_along_dim(sampling_table, indices=ngram_keys, dim=2)
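A minimal sketch of this table-lookup trick (sizes are arbitrary placeholders, not necessarily the library defaults):

```python
import torch

# Pre-compute a Bernoulli(0.5) table once, then map int64 keys into it via modulo.
table_size = 2**16
generator = torch.Generator().manual_seed(0)
sampling_table = torch.bernoulli(torch.full((table_size,), 0.5), generator=generator).long()

ngram_keys = torch.randint(0, 2**62, (2, 5, 4))   # (batch_size, num_ngrams, depth)
g_values = torch.take_along_dim(
    sampling_table.reshape(1, 1, table_size),
    indices=ngram_keys % table_size,
    dim=2,
)
print(g_values.shape)  # torch.Size([2, 5, 4]); entries are 0 or 1
```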
def _check_input_ids_shape(self, input_ids: torch.LongTensor):
"""Checks the shape of input ids."""
if len(input_ids.shape) != 2:
raise ValueError("Input ids should be of shape (batch_size, input_len), but is" f" {input_ids.shape}")
def compute_g_values(self, input_ids: torch.LongTensor) -> torch.LongTensor:
"""
Computes g values for each ngram from the given sequence of tokens.
Args:
input_ids (`torch.LongTensor`):
Input token ids (batch_size, input_len).
Returns:
G values (batch_size, input_len - (ngram_len - 1), depth).
"""
self._check_input_ids_shape(input_ids)
ngrams = input_ids.unfold(dimension=1, size=self.ngram_len, step=1)
ngram_keys = self.compute_ngram_keys(ngrams)
return self.sample_g_values(ngram_keys)
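As a quick illustration of the `unfold` call above (toy values, assuming `ngram_len=3`):

```python
import torch

input_ids = torch.tensor([[11, 12, 13, 14, 15, 16]])   # (batch_size=1, input_len=6)
ngrams = input_ids.unfold(dimension=1, size=3, step=1)
print(ngrams.shape)  # torch.Size([1, 4, 3]) -> input_len - (ngram_len - 1) ngrams per sequence
print(ngrams[0])     # tensor([[11, 12, 13], [12, 13, 14], [13, 14, 15], [14, 15, 16]])
```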
def compute_context_repetition_mask(self, input_ids: torch.LongTensor) -> torch.LongTensor:
"""
Computes repetition mask.
0 and 1 stand for repeated and not repeated context n-1 grams respectively.
Args:
input_ids (`torch.LongTensor`):
Input token ids (batch_size, input_len).
Returns:
Repetitions mask (batch_size, input_len - (ngram_len - 1)).
"""
self._check_input_ids_shape(input_ids)
batch_size, _ = input_ids.shape
state = SynthIDTextWatermarkState(
batch_size=batch_size,
ngram_len=self.ngram_len,
context_history_size=self.context_history_size,
device=self.device,
)
contexts = input_ids[:, :-1].unfold(
dimension=1,
size=self.ngram_len - 1,
step=1,
)
_, num_contexts, _ = contexts.shape
are_repeated_contexts = []
for i in range(num_contexts):
context = contexts[:, i, :]
hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
context_hash = self.accumulate_hash(hash_result, context)[:, None]
is_repeated_context = (state.context_history == context_hash).any(
dim=1,
keepdim=True,
)
are_repeated_contexts.append(is_repeated_context)
state.context_history = torch.concat(
(context_hash, state.context_history),
dim=1,
)[:, :-1]
are_repeated_contexts = torch.concat(are_repeated_contexts, dim=1)
return torch.logical_not(are_repeated_contexts)
def compute_eos_token_mask(self, input_ids: torch.LongTensor, eos_token_id: int) -> torch.LongTensor:
"""
Computes the EOS token mask.
1 marks tokens before the first EOS token in a sequence; the first EOS token and everything after it are marked 0.
Args:
input_ids (`torch.LongTensor`):
Input token ids (batch_size, input_len).
eos_token_id (`int`):
EOS token ID.
Returns:
EOS token mask (batch_size, input_len).
"""
self._check_input_ids_shape(input_ids)
noneos_masks = []
all_eos_equated = input_ids == eos_token_id
for eos_equated in all_eos_equated:
nonzero_idx = torch.nonzero(eos_equated)
noneos_mask = torch.ones_like(eos_equated)
if nonzero_idx.shape[0] != 0:
noneos_mask[nonzero_idx[0][0] :] = 0
noneos_masks.append(noneos_mask)
return torch.stack(noneos_masks, dim=0)
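A toy walk-through of the per-sequence masking above (the `eos_token_id` value is arbitrary):

```python
import torch

input_ids = torch.tensor([[5, 7, 0, 9, 0]])
eos_token_id = 0
noneos_mask = torch.ones_like(input_ids[0])
nonzero_idx = torch.nonzero(input_ids[0] == eos_token_id)
if nonzero_idx.shape[0] != 0:
    noneos_mask[nonzero_idx[0][0]:] = 0
print(noneos_mask)  # tensor([1, 1, 0, 0, 0]) -> everything from the first EOS onward is masked out
```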
def expected_mean_g_value(self, vocab_size: int, coinflip_prob: float = 0.5) -> float:
"""
Compute expected mean g-value after watermarking, assuming uniform LM dist.
This is the theoretical expected value for single-layer watermarking.
Args:
vocab_size (`int`):
The size of the vocabulary.
coinflip_prob (`float`, *optional*, defaults to 0.5):
Probability of 1 in boolean prf.
Returns:
The expected mean g-value for watermarked text.
"""
return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size))
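A quick numeric check of this closed form (plain Python, using the default `coinflip_prob=0.5`):

```python
# expected mean g-value = 0.5 + 0.25 * (1 - 1 / vocab_size) when coinflip_prob = 0.5
for vocab_size in (2, 100, 32000):
    expected = 0.5 + 0.5 * (1 - 0.5) * (1 - 1 / vocab_size)
    print(vocab_size, round(expected, 4))  # 2 -> 0.625, 100 -> 0.7475, 32000 -> 0.75
```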

View File

@@ -92,7 +92,6 @@ from .logits_process import (
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
from .stopping_criteria import (
ConfidenceCriteria,
@@ -1011,15 +1010,7 @@ class GenerationMixin:
)
if generation_config.watermarking_config is not None:
processors.append(
WatermarkLogitsProcessor(
vocab_size=self.config.vocab_size,
device=device,
greenlist_ratio=generation_config.watermarking_config.greenlist_ratio,
bias=generation_config.watermarking_config.bias,
hashing_key=generation_config.watermarking_config.hashing_key,
seeding_scheme=generation_config.watermarking_config.seeding_scheme,
context_width=generation_config.watermarking_config.context_width,
)
generation_config.watermarking_config.construct_processor(self.config.vocab_size, device)
)
# TODO (joao): find a strategy to specify the order of the processors
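For context, after this change `generate` builds the SynthID processor directly from the watermarking config. A rough usage sketch follows; the checkpoint name and key values are placeholders, not real watermarking keys:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")   # example checkpoint
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")

# Dummy keys for illustration only; real deployments keep their keys private.
watermarking_config = SynthIDTextWatermarkingConfig(
    keys=[654, 400, 836, 123, 340, 443, 597, 160, 57, 29],
    ngram_len=5,
)

inputs = tokenizer(["Once upon a time,"], return_tensors="pt")
out = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=True, max_new_tokens=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```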

View File

@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team
# Copyright 2024 The HuggingFace Inc. team and Google DeepMind.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,19 +16,22 @@
import collections
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
from torch.nn import BCELoss
from ..configuration_utils import PretrainedConfig
from ..utils import is_torch_available, logging
from .configuration_utils import WatermarkingConfig
from ..modeling_utils import PreTrainedModel
from ..utils import ModelOutput, is_torch_available, logging
from .configuration_utils import PretrainedConfig, WatermarkingConfig
if is_torch_available():
import torch
from .logits_process import WatermarkLogitsProcessor
from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
logger = logging.get_logger(__name__)
@@ -237,3 +240,310 @@ class WatermarkDetector:
confidence=confidence,
)
return prediction
class BayesianDetectorConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`BayesianDetectorModel`]. It is used to
instantiate a Bayesian Detector model according to the specified arguments.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
watermarking_depth (`int`, *optional*):
The number of tournament layers.
base_rate (`float`, *optional*, defaults to 0.5):
Prior probability P(w) that a text is watermarked.
"""
def __init__(self, watermarking_depth: int = None, base_rate: float = 0.5, **kwargs):
self.watermarking_depth = watermarking_depth
self.base_rate = base_rate
# These can be set later to store information about this detector.
self.model_name = None
self.watermarking_config = None
super().__init__(**kwargs)
def set_detector_information(self, model_name, watermarking_config):
self.model_name = model_name
self.watermarking_config = watermarking_config
@dataclass
class BayesianWatermarkDetectorModelOutput(ModelOutput):
"""
Base class for outputs of models predicting if the text is watermarked.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Binary cross-entropy detection loss, with an L2 penalty on the watermarked-likelihood `delta` parameters.
posterior_probabilities (`torch.FloatTensor` of shape `(1,)`):
Posterior probabilities that the texts are watermarked.
"""
loss: Optional[torch.FloatTensor] = None
posterior_probabilities: Optional[torch.FloatTensor] = None
class BayesianDetectorWatermarkedLikelihood(nn.Module):
"""Watermarked likelihood model for binary-valued g-values.
This takes in g-values and returns p(g_values|watermarked).
"""
def __init__(self, watermarking_depth: int):
"""Initializes the model parameters."""
super().__init__()
self.watermarking_depth = watermarking_depth
self.beta = torch.nn.Parameter(-2.5 + 0.001 * torch.randn(1, 1, watermarking_depth))
self.delta = torch.nn.Parameter(0.001 * torch.randn(1, 1, self.watermarking_depth, watermarking_depth))
def _compute_latents(self, g_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes the unique token probability distribution given g-values.
Args:
g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
PRF values.
Returns:
p_one_unique_token and p_two_unique_tokens, both of shape
[batch_size, seq_len, watermarking_depth]. p_one_unique_token[i,t,l]
gives the probability of there being one unique token in a tournament
match on layer l, on timestep t, for batch item i.
p_one_unique_token[i,t,l] + p_two_unique_token[i,t,l] = 1.
"""
# Tile g-values to produce feature vectors for predicting the latents
# for each layer in the tournament; our model for the latents psi is a
# logistic regression model psi = sigmoid(delta * x + beta).
# [batch_size, seq_len, watermarking_depth, watermarking_depth]
x = torch.repeat_interleave(torch.unsqueeze(g_values, dim=-2), self.watermarking_depth, axis=-2)
# mask all elements above -1 diagonal for autoregressive factorization
x = torch.tril(x, diagonal=-1)
# [batch_size, seq_len, watermarking_depth]
# (i, j, k, l) x (i, j, k, l) -> (i, j, k) einsum equivalent
logits = (self.delta[..., None, :] @ x.type(self.delta.dtype)[..., None]).squeeze() + self.beta
p_two_unique_tokens = torch.sigmoid(logits)
p_one_unique_token = 1 - p_two_unique_tokens
return p_one_unique_token, p_two_unique_tokens
def forward(self, g_values: torch.Tensor) -> torch.Tensor:
"""Computes the likelihoods P(g_values|watermarked).
Args:
g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
g-values (values 0 or 1)
Returns:
p(g_values|watermarked) of shape [batch_size, seq_len, watermarking_depth].
"""
p_one_unique_token, p_two_unique_tokens = self._compute_latents(g_values)
# P(g_tl | watermarked) is equal to
# 0.5 * [ (g_tl+0.5) * p_two_unique_tokens + p_one_unique_token].
return 0.5 * ((g_values + 0.5) * p_two_unique_tokens + p_one_unique_token)
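A quick sanity check of this expression: writing p for `p_two_unique_tokens` at some position and layer, the two possible likelihoods are 0.5 + 0.25*p (for g=1) and 0.5 - 0.25*p (for g=0), so they always sum to 1 and reduce to the unwatermarked value 0.5 when p = 0.

```python
# Plain-Python check of the likelihood formula above (p is an arbitrary value).
p = 0.8                                   # p_two_unique_tokens at some (t, l)
p_g1 = 0.5 * ((1 + 0.5) * p + (1 - p))    # P(g=1 | watermarked) = 0.5 + 0.25 * p = 0.70
p_g0 = 0.5 * ((0 + 0.5) * p + (1 - p))    # P(g=0 | watermarked) = 0.5 - 0.25 * p = 0.30
assert abs(p_g1 + p_g0 - 1.0) < 1e-9
```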
class BayesianDetectorModel(PreTrainedModel):
r"""
Bayesian classifier for watermark detection.
This detector uses Bayes' rule to compute a watermarking score, which is the sigmoid of the log of the ratio of the
posterior probabilities P(watermarked|g_values) and P(unwatermarked|g_values). Please see the section on
BayesianScore in the paper for further details.
Paper URL: https://www.nature.com/articles/s41586-024-08025-4
Note that this detector only works with non-distortionary Tournament-based watermarking using the Bernoulli(0.5)
g-value distribution.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`BayesianDetectorConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
config_class = BayesianDetectorConfig
base_model_prefix = "model"
def __init__(self, config):
super().__init__(config)
self.watermarking_depth = config.watermarking_depth
self.base_rate = config.base_rate
self.likelihood_model_watermarked = BayesianDetectorWatermarkedLikelihood(
watermarking_depth=self.watermarking_depth
)
self.prior = torch.nn.Parameter(torch.tensor([self.base_rate]))
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Parameter):
module.weight.data.normal_(mean=0.0, std=0.02)
def _compute_posterior(
self,
likelihoods_watermarked: torch.Tensor,
likelihoods_unwatermarked: torch.Tensor,
mask: torch.Tensor,
prior: float,
) -> torch.Tensor:
"""
Compute posterior P(w|g) given likelihoods, mask and prior.
Args:
likelihoods_watermarked (`torch.Tensor` of shape `(batch, length, depth)`):
Likelihoods P(g_values|watermarked) of g-values under watermarked model.
likelihoods_unwatermarked (`torch.Tensor` of shape `(batch, length, depth)`):
Likelihoods P(g_values|unwatermarked) of g-values under unwatermarked model.
mask (`torch.Tensor` of shape `(batch, length)`):
A binary array indicating which g-values should be used. g-values with mask value 0 are discarded.
prior (`float`):
the prior probability P(w) that the text is watermarked.
Returns:
Posterior probability P(watermarked|g_values), shape [batch].
"""
mask = torch.unsqueeze(mask, dim=-1)
prior = torch.clamp(prior, min=1e-5, max=1 - 1e-5)
log_likelihoods_watermarked = torch.log(torch.clamp(likelihoods_watermarked, min=1e-30, max=float("inf")))
log_likelihoods_unwatermarked = torch.log(torch.clamp(likelihoods_unwatermarked, min=1e-30, max=float("inf")))
log_odds = log_likelihoods_watermarked - log_likelihoods_unwatermarked
# Sum relative surprisals (log odds) across all token positions and layers.
relative_surprisal_likelihood = torch.einsum("i...->i", log_odds * mask)
# Compute the relative surprisal prior
relative_surprisal_prior = torch.log(prior) - torch.log(1 - prior)
# Combine prior and likelihood.
# [batch_size]
relative_surprisal = relative_surprisal_prior + relative_surprisal_likelihood
# Compute the posterior probability P(w|g) = sigmoid(relative_surprisal).
return torch.sigmoid(relative_surprisal)
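A minimal sketch of the same Bayes update in log-odds space, with toy numbers (this is not library code, just an illustration of the arithmetic above):

```python
import torch

likelihoods_watermarked = torch.tensor([[[0.6, 0.7], [0.55, 0.65]]])   # (batch=1, length=2, depth=2)
likelihoods_unwatermarked = torch.full_like(likelihoods_watermarked, 0.5)
mask = torch.ones(1, 2)
prior = torch.tensor(0.5)

log_odds = torch.log(likelihoods_watermarked) - torch.log(likelihoods_unwatermarked)
relative_surprisal = torch.log(prior) - torch.log(1 - prior) + (log_odds * mask[..., None]).sum(dim=(1, 2))
print(torch.sigmoid(relative_surprisal))  # posterior P(watermarked | g_values), shape (1,)
```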
def forward(
self,
g_values: torch.Tensor,
mask: torch.Tensor,
labels: Optional[torch.Tensor] = None,
loss_batch_weight=1,
return_dict=False,
) -> BayesianWatermarkDetectorModelOutput:
"""
Computes the watermarked posterior P(watermarked|g_values).
Args:
g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth, ...)`):
g-values (with values 0 or 1)
mask:
A binary array shape [batch_size, seq_len] indicating which g-values should be used. g-values with mask
value 0 are discarded.
Returns:
p(watermarked | g_values), of shape [batch_size].
"""
likelihoods_watermarked = self.likelihood_model_watermarked(g_values)
likelihoods_unwatermarked = 0.5 * torch.ones_like(g_values)
out = self._compute_posterior(
likelihoods_watermarked=likelihoods_watermarked,
likelihoods_unwatermarked=likelihoods_unwatermarked,
mask=mask,
prior=self.prior,
)
loss = None
if labels is not None:
loss_fct = BCELoss()
loss_unwweight = torch.sum(self.likelihood_model_watermarked.delta**2)
loss_weight = loss_unwweight * loss_batch_weight
loss = loss_fct(torch.clamp(out, 1e-5, 1 - 1e-5), labels) + loss_weight
if not return_dict:
return (out,) if loss is None else (out, loss)
return BayesianWatermarkDetectorModelOutput(loss=loss, posterior_probabilities=out)
class SynthIDTextWatermarkDetector:
r"""
SynthID text watermark detector class.
This class has to be initialized with the trained Bayesian detector module. Check the script
in examples/research_projects/synthid_text/detector_training.py for an example of training/saving/loading this
detector module. The folder also showcases an example use case of this detector.
Parameters:
detector_module ([`BayesianDetectorModel`]):
Bayesian detector module object initialized with parameters.
Check examples/research_projects/synthid_text/detector_training.py for usage.
logits_processor (`SynthIDTextWatermarkLogitsProcessor`):
The logits processor used for watermarking.
tokenizer (`Any`):
The tokenizer used for the model.
Examples:
```python
>>> from transformers import (
... AutoTokenizer, BayesianDetectorModel, SynthIDTextWatermarkLogitsProcessor, SynthIDTextWatermarkDetector
... )
>>> # Load the detector. See examples/research_projects/synthid_text for training a detector.
>>> detector_model = BayesianDetectorModel.from_pretrained("joaogante/dummy_synthid_detector")
>>> logits_processor = SynthIDTextWatermarkLogitsProcessor(
... **detector_model.config.watermarking_config, device="cpu"
... )
>>> tokenizer = AutoTokenizer.from_pretrained(detector_model.config.model_name)
>>> detector = SynthIDTextWatermarkDetector(detector_model, logits_processor, tokenizer)
>>> # Test whether a certain string is watermarked
>>> test_input = tokenizer(["This is a test input"], return_tensors="pt")
>>> is_watermarked = detector(test_input.input_ids)
```
"""
def __init__(
self,
detector_module: BayesianDetectorModel,
logits_processor: SynthIDTextWatermarkLogitsProcessor,
tokenizer: Any,
):
self.detector_module = detector_module
self.logits_processor = logits_processor
self.tokenizer = tokenizer
def __call__(self, tokenized_outputs: torch.Tensor):
# eos mask is computed, skip first ngram_len - 1 tokens
# eos_mask will be of shape [batch_size, output_len]
eos_token_mask = self.logits_processor.compute_eos_token_mask(
input_ids=tokenized_outputs,
eos_token_id=self.tokenizer.eos_token_id,
)[:, self.logits_processor.ngram_len - 1 :]
# context repetition mask is computed
context_repetition_mask = self.logits_processor.compute_context_repetition_mask(
input_ids=tokenized_outputs,
)
# context repetition mask shape [batch_size, output_len - (ngram_len - 1)]
combined_mask = context_repetition_mask * eos_token_mask
g_values = self.logits_processor.compute_g_values(
input_ids=tokenized_outputs,
)
# g values shape [batch_size, output_len - (ngram_len - 1), depth]
return self.detector_module(g_values, combined_mask)
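Continuing the docstring example above: with the detector module's default `return_dict=False`, the call returns a one-element tuple holding the posterior probabilities, so a typical check looks like the sketch below (the 0.5 threshold is an arbitrary choice):

```python
posterior_probabilities = detector(test_input.input_ids)[0]
print((posterior_probabilities > 0.5).tolist())  # True where the text is judged to be watermarked
```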

View File

@@ -191,6 +191,20 @@ class AlternatingCodebooksLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["torch"])
class BayesianDetectorConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BayesianDetectorModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BeamScorer(metaclass=DummyObject):
_backends = ["torch"]
@@ -457,6 +471,27 @@ class SuppressTokensLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["torch"])
class SynthIDTextWatermarkDetector(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SynthIDTextWatermarkingConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SynthIDTextWatermarkLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TemperatureLogitsWarper(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -16,6 +16,7 @@
import unittest
from typing import List, Union
import numpy as np
from parameterized import parameterized
from transformers import is_torch_available
@@ -48,6 +49,7 @@ if is_torch_available():
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
SequenceBiasLogitsProcessor,
SynthIDTextWatermarkLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
@@ -975,3 +977,187 @@ class LogitsProcessorTest(unittest.TestCase):
scores_wo_bias = scores[:, -1].clone()
out = watermark(input_ids=input_ids, scores=scores)
self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())
@parameterized.expand([(5, 3, 10000), (10, 5, 1000)])
def test_synthidtext_watermarking_processor_bias_uniformity(self, ngram_len, num_layers, vocab_size):
"""Test SynthID watermarked distribution bias uniformity over iterations."""
torch.manual_seed(0)
np.random.seed(0)
watermarking_config = {
"ngram_len": ngram_len,
"keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
"sampling_table_size": 2**16,
"sampling_table_seed": 0,
"context_history_size": 512,
"device": torch_device,
}
batch_size = 100000
ngrams = torch.randint(
low=0,
high=vocab_size,
size=(batch_size, ngram_len),
device=torch_device,
)
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
g_values = logits_processor.compute_g_values(ngrams)
g_values_mean = torch.mean(torch.mean(g_values.float(), dim=0))
self.assertAlmostEqual(g_values_mean, 0.5, delta=0.01)
@parameterized.expand([(10000, 3), (1000, 20)])
def test_synthidtext_watermark_processor_bias_uniformity_across_vocab(self, vocab_size, num_layers):
"""Test SynthID watermarked distribution bias uniformity over vocabs of the model."""
batch_size = 1000
ngram_len = 5
torch.manual_seed(0)
np.random.seed(0)
watermarking_config = {
"ngram_len": ngram_len,
"keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
"sampling_table_size": 2**16,
"sampling_table_seed": 0,
"context_history_size": 512,
"device": torch_device,
}
n_minus_1_grams = torch.randint(
low=0,
high=vocab_size,
size=(batch_size, watermarking_config["ngram_len"] - 1),
device=torch_device,
)
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
ngram_keys, _ = logits_processor._compute_keys(
n_minus_1_grams,
torch.stack([torch.arange(vocab_size, device=torch_device) for _ in range(batch_size)]),
)
g_values = logits_processor.sample_g_values(ngram_keys)
# g_values shape should be [batch_size, vocab_size, num_layers]
g_values_mean = torch.mean(torch.mean(g_values.float(), dim=1))
self.assertAlmostEqual(g_values_mean, 0.5, delta=0.001)
@parameterized.expand([(2, "uniform"), (10, "uniform"), (2, "random"), (10, "random")])
def test_synthidtext_watermark_processor_distributional_convergence(self, vocab_size, logits_type):
"""Check if watermarked distribution converges to unwatermarked logits distribution."""
batch_size = 1500
num_keys = 1000
updated_softmaxes = 0
np.random.seed(0)
torch.manual_seed(0)
if logits_type == "uniform":
fixed_logits = torch.ones((batch_size, vocab_size), device=torch_device)
elif logits_type == "random":
fixed_logits = torch.rand(
(
1,
vocab_size,
),
device=torch_device,
)
fixed_logits = fixed_logits.repeat(batch_size, 1)
else:
raise ValueError(f"Unrecognized logits_type {logits_type}")
for _ in range(num_keys):
watermarking_config = {
"ngram_len": 5,
"keys": np.random.randint(0, 10**9, size=(1,), dtype=np.int64),
"sampling_table_size": 2**16,
"sampling_table_seed": 0,
"context_history_size": 1024,
"device": torch_device,
}
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
ngrams = torch.randint(
low=0,
high=vocab_size,
size=(batch_size, watermarking_config["ngram_len"]),
device=torch_device,
)
# Insert ngram-1 into logit_processor state.
for idx in range(watermarking_config["ngram_len"] - 1):
_ = logits_processor(ngrams[:, :idx], fixed_logits)
updated_scores = logits_processor(ngrams, fixed_logits)
updated_softmaxes += torch.nn.functional.softmax(updated_scores, dim=1).cpu().numpy()
updated_softmaxes = np.mean(updated_softmaxes, axis=0) / num_keys
is_close = torch.all(
torch.isclose(
torch.tensor(updated_softmaxes, device=torch_device),
torch.nn.Softmax()(fixed_logits[0]), # Take any batch entry, all are same.
atol=1e-3,
rtol=0,
)
)
self.assertTrue(is_close)
@parameterized.expand([(2, 10, 1, 0.01), (100, 5, 1, 0.01), (100, 10, 2, 0.02)])
def test_synthidtext_watermark_processor_bias_test(self, vocab_size, ngram_len, num_layers, atol):
"""Test SynthID watermarking bias matches theoretical value."""
batch_size = 20000
generator = torch.Generator(device=torch_device).manual_seed(0)
np.random.seed(0)
keys = [np.random.randint(0, 10**9) for _ in range(num_layers)]
# Use 10**9 rather than vocab_size to ensure variety in (n-1)-grams.
context = torch.randint(
low=0,
high=10**9,
size=(batch_size, ngram_len - 1),
dtype=torch.int64,
generator=generator,
device=torch_device,
)
context_history_size = 1024
logits_processor = SynthIDTextWatermarkLogitsProcessor(
ngram_len=ngram_len,
keys=keys,
sampling_table_size=2**16,
sampling_table_seed=0,
context_history_size=context_history_size,
device=torch_device,
)
scores = torch.ones(
(batch_size, vocab_size),
dtype=torch.float64,
device=torch_device,
)
# Init state of the logits processor.
logits_processor(context, scores)
# insert context into the state.
for idx in range(1, ngram_len - 1):
_ = logits_processor(context[:, :idx], scores)
updated_scores = logits_processor(context, scores)
probs = torch.nn.functional.softmax(updated_scores, dim=1)
generator = torch.Generator(device=torch_device).manual_seed(0)
next_tokens = torch.multinomial(
probs,
num_samples=1,
generator=generator,
)
ngrams = torch.concat((context, next_tokens), dim=1)
g_values = logits_processor.compute_g_values(ngrams)
mean_g_values = g_values.mean(dtype=torch.float64, dim=(0, 1))
expected_mean_g_value = logits_processor.expected_mean_g_value(
vocab_size=vocab_size,
)
is_close = torch.all(
torch.isclose(
mean_g_values,
torch.tensor(expected_mean_g_value, dtype=torch.float64, device=torch_device),
atol=atol,
rtol=0,
)
)
self.assertTrue(is_close)

View File

@@ -84,6 +84,7 @@ if is_torch_available():
SampleEncoderDecoderOutput,
StoppingCriteria,
StoppingCriteriaList,
SynthIDTextWatermarkingConfig,
WatermarkDetector,
WatermarkingConfig,
)
@@ -2517,9 +2518,9 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(low_output.tolist(), high_output.tolist())
@slow
def test_watermark_generation(self):
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
def test_green_red_watermark_generation(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
input_len = model_inputs["input_ids"].shape[-1]
@@ -2548,6 +2549,61 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
self.assertListEqual(detection_out.prediction.tolist(), [False])
"""Check the mean bias inserted by the watermarking algorithm."""
@slow
def test_synthid_text_watermark_generation_mean_expected_bias(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
input_len = 5
batch_size = 200
# construct the watermarking config and its matching logits processor so we can compute g-values on the generated text
watermark_config = SynthIDTextWatermarkingConfig(keys=[10, 20], ngram_len=5, debug_mode=True)
logits_processor = watermark_config.construct_processor(model.config.vocab_size, torch_device)
mean_g_values_repeats = []
for _ in range(40):
input_ids = torch.zeros(
(batch_size, input_len),
dtype=torch.int64,
device=torch_device,
)
model_inputs = {
"input_ids": input_ids,
"attention_mask": torch.ones_like(input_ids, device=torch_device),
}
output = model.generate(
**model_inputs, watermarking_config=watermark_config, do_sample=True, max_length=500, top_k=1000
)
g_values = logits_processor.compute_g_values(input_ids=output[:, input_len:])
context_repetition_mask = logits_processor.compute_context_repetition_mask(
input_ids=output[:, input_len:],
).unsqueeze(dim=2)
mean_g_values = torch.masked.mean(
g_values,
mask=context_repetition_mask,
dim=0,
keepdim=True,
dtype=torch.float64,
)
mean_g_values_repeats.append(mean_g_values)
mean_g_values = torch.concat(mean_g_values_repeats, dim=0).mean(dim=0)
expected_mean_g_value = logits_processor.expected_mean_g_value(
vocab_size=model.config.vocab_size,
)
atol = 0.03
is_close = torch.isclose(
mean_g_values,
torch.tensor(expected_mean_g_value, dtype=torch.float64),
atol=atol,
rtol=0,
)
self.assertTrue(torch.all(is_close))
@slow
def test_beam_search_example_integration(self):
# PT-only test: TF doesn't have a BeamSearchScorer