Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-16 11:08:23 +06:00)
Add SynthID (watermarking by Google DeepMind) (#34350)
* Add SynthIDTextWatermarkLogitsProcessor * Resolving comments. * Resolving comments. * Resolving commits. * Improving SynthIDWatermark tests. * switch to PT version * detector as pretrained model + style * update training + style * rebase * Update logits_process.py * Improving SynthIDWatermark tests. * Shift detector training to wikitext negatives and stabilize with lower learning rate. * Clean up. * in for 7B * cleanup * Support python 3.8. * README and final cleanup. * HF Hub upload and initialize. * Update requirements for synthid_text. * Adding SynthIDTextWatermarkDetector. * Detector testing. * Documentation changes. * Copyrights fix. * Fix detector api. * ironing out errors * ironing out errors * training checks * make fixup and make fix-copies * docstrings and add to docs * copyright * BC * test docstrings * move import * protect type hints * top level imports * watermarking example * direct imports * tpr fpr meaning * process_kwargs * SynthIDTextWatermarkingConfig docstring * assert -> exception * example updates * no immutable dict (can't be serialized) * pack fn * einsum equivalent * import order * fix test on gpu * add detector example --------- Co-authored-by: Sumedh Ghaisas <sumedhg@google.com> Co-authored-by: Marc Sun <marc@huggingface.co> Co-authored-by: sumedhghaisas2 <138781311+sumedhghaisas2@users.noreply.github.com> Co-authored-by: raushan <raushan@huggingface.co>
This commit is contained in:
parent e50bf61dec
commit b0f0c61899
@@ -185,6 +185,9 @@ generation.
 [[autodoc]] SuppressTokensLogitsProcessor
     - __call__
 
+[[autodoc]] SynthIDTextWatermarkLogitsProcessor
+    - __call__
+
 [[autodoc]] TemperatureLogitsWarper
     - __call__
 
@@ -418,5 +421,20 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 ## Watermark Utils
 
+[[autodoc]] WatermarkingConfig
+    - __call__
+
 [[autodoc]] WatermarkDetector
     - __call__
 
+[[autodoc]] BayesianDetectorConfig
+    - __call__
+
+[[autodoc]] BayesianDetectorModel
+    - __call__
+
+[[autodoc]] SynthIDTextWatermarkingConfig
+    - __call__
+
+[[autodoc]] SynthIDTextWatermarkDetector
+    - __call__
@@ -41,8 +41,6 @@ like token streaming.
     - validate
     - get_generation_mode
 
-[[autodoc]] generation.WatermarkingConfig
-
 ## GenerationMixin
 
 [[autodoc]] GenerationMixin
examples/research_projects/synthid_text/README.md (new file, 34 lines)
@@ -0,0 +1,34 @@
# SynthID Text

This project showcases the use of SynthID Text for watermarking LLM output. The code in this repo also
demonstrates training the detector used to detect such watermarked text. The trained detector can be uploaded to
a private HF Hub repo (private for security reasons) and re-initialized later through pretrained model loading, which is also shown in this script.

See our blog post: https://huggingface.co/blog/synthid-text


## Python version

You need Python 3.9 to run this example.

## Installation and running

Once you have installed transformers, install the requirements for this project through the requirements.txt provided in this folder.

```
pip install -r requirements.txt
```

## To run the detector training

```
python detector_training.py --model_name=google/gemma-7b-it
```

Check the script for more parameters that are tunable, and see the paper at
https://www.nature.com/articles/s41586-024-08025-4 for more information on these parameters.

## Caveat

Make sure to run the training of the detector and the detection on the same hardware
(CPU, GPU or TPU) to get consistent results (we use deterministic randomness which is hardware dependent).
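The README's note about re-initializing the detector through pretrained model loading is implemented at the end of `detector_training.py`; for quick reference, here is a minimal sketch of that flow. It is not part of the diff: the repo id is a placeholder, and it assumes the detector config was populated with `set_detector_information` before being pushed to the Hub, as `detector_training.py` does.

```python
import torch
from transformers import (
    AutoTokenizer,
    BayesianDetectorModel,
    SynthIDTextWatermarkDetector,
    SynthIDTextWatermarkLogitsProcessor,
)

DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Placeholder repo id: pass the same value you used for --hf_hub_model_name.
detector_model = BayesianDetectorModel.from_pretrained("your-username/synthid-detector").to(DEVICE)

# The config stores the base model name and watermarking config set at training time.
watermarking_config = detector_model.config.watermarking_config
tokenizer = AutoTokenizer.from_pretrained(detector_model.config.model_name)
tokenizer.pad_token = tokenizer.eos_token

logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config, device=DEVICE)
detector = SynthIDTextWatermarkDetector(detector_model, logits_processor, tokenizer)

# `completion_ids` should contain only the generated tokens (prompt stripped),
# exactly as in the evaluation branch of detector_training.py.
# scores = detector(completion_ids)  # probability of being watermarked, per example
```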
examples/research_projects/synthid_text/detector_training.py (new file, 502 lines)
@@ -0,0 +1,502 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 Google DeepMind.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
import enum
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
BayesianDetectorConfig,
|
||||||
|
BayesianDetectorModel,
|
||||||
|
SynthIDTextWatermarkDetector,
|
||||||
|
SynthIDTextWatermarkingConfig,
|
||||||
|
SynthIDTextWatermarkLogitsProcessor,
|
||||||
|
)
|
||||||
|
from utils import (
|
||||||
|
get_tokenized_uwm_outputs,
|
||||||
|
get_tokenized_wm_outputs,
|
||||||
|
process_raw_model_outputs,
|
||||||
|
update_fn_if_fpr_tpr,
|
||||||
|
upload_model_to_hf,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@enum.unique
|
||||||
|
class ValidationMetric(enum.Enum):
|
||||||
|
"""Direction along the z-axis."""
|
||||||
|
|
||||||
|
TPR_AT_FPR = "tpr_at_fpr"
|
||||||
|
CROSS_ENTROPY = "cross_entropy"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class TrainingArguments:
|
||||||
|
"""Training arguments pertaining to the training loop itself."""
|
||||||
|
|
||||||
|
eval_metric: Optional[str] = dataclasses.field(
|
||||||
|
default=ValidationMetric.TPR_AT_FPR, metadata={"help": "The evaluation metric used."}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def train_detector(
|
||||||
|
detector: torch.nn.Module,
|
||||||
|
g_values: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
watermarked: torch.Tensor,
|
||||||
|
epochs: int = 250,
|
||||||
|
learning_rate: float = 1e-3,
|
||||||
|
minibatch_size: int = 64,
|
||||||
|
seed: int = 0,
|
||||||
|
l2_weight: float = 0.0,
|
||||||
|
shuffle: bool = True,
|
||||||
|
g_values_val: Optional[torch.Tensor] = None,
|
||||||
|
mask_val: Optional[torch.Tensor] = None,
|
||||||
|
watermarked_val: Optional[torch.Tensor] = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
validation_metric: ValidationMetric = ValidationMetric.TPR_AT_FPR,
|
||||||
|
) -> Tuple[Dict[str, Any], float]:
|
||||||
|
"""Trains a Bayesian detector model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
g_values: g-values of shape [num_train, seq_len, watermarking_depth].
|
||||||
|
mask: A binary array shape [num_train, seq_len] indicating which g-values
|
||||||
|
should be used. g-values with mask value 0 are discarded.
|
||||||
|
watermarked: A binary array of shape [num_train] indicating whether the
|
||||||
|
example is watermarked (0: unwatermarked, 1: watermarked).
|
||||||
|
epochs: Number of epochs to train for.
|
||||||
|
learning_rate: Learning rate for optimizer.
|
||||||
|
minibatch_size: Minibatch size for training. Note that a minibatch
|
||||||
|
requires ~ 32 * minibatch_size * seq_len * watermarking_depth *
watermarking_depth bits of memory.
|
||||||
|
seed: Seed for parameter initialization.
|
||||||
|
l2_weight: Weight to apply to L2 regularization for delta parameters.
|
||||||
|
shuffle: Whether to shuffle before training.
|
||||||
|
g_values_val: Validation g-values of shape [num_val, seq_len,
|
||||||
|
watermarking_depth].
|
||||||
|
mask_val: Validation mask of shape [num_val, seq_len].
|
||||||
|
watermarked_val: Validation watermark labels of shape [num_val].
|
||||||
|
verbose: Boolean indicating verbosity of training. If true, the loss will
|
||||||
|
be printed. Defaulted to False.
|
||||||
|
validation_metric: Metric to use for validation. If TPR_AT_FPR, use TPR@FPR=1%;
otherwise, use cross entropy loss.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of
|
||||||
|
training_history: Training history keyed by epoch number where the
|
||||||
|
values are
|
||||||
|
dictionaries containing the loss, validation loss, and model
|
||||||
|
parameters,
|
||||||
|
keyed by
|
||||||
|
'loss', 'val_loss', and 'params', respectively.
|
||||||
|
min_val_loss: Minimum validation loss achieved during training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Set the random seed for reproducibility
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
# Shuffle the data if required
|
||||||
|
if shuffle:
|
||||||
|
indices = torch.randperm(len(g_values))
|
||||||
|
g_values = g_values[indices]
|
||||||
|
mask = mask[indices]
|
||||||
|
watermarked = watermarked[indices]
|
||||||
|
|
||||||
|
# Initialize optimizer
|
||||||
|
optimizer = torch.optim.Adam(detector.parameters(), lr=learning_rate)
|
||||||
|
history = {}
|
||||||
|
min_val_loss = float("inf")
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
losses = []
|
||||||
|
detector.train()
|
||||||
|
num_batches = len(g_values) // minibatch_size
|
||||||
|
for i in range(0, len(g_values), minibatch_size):
|
||||||
|
end = i + minibatch_size
|
||||||
|
if end > len(g_values):
|
||||||
|
break
|
||||||
|
loss_batch_weight = l2_weight / num_batches
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss = detector(
|
||||||
|
g_values=g_values[i:end],
|
||||||
|
mask=mask[i:end],
|
||||||
|
labels=watermarked[i:end],
|
||||||
|
loss_batch_weight=loss_batch_weight,
|
||||||
|
)[1]
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
losses.append(loss.item())
|
||||||
|
train_loss = sum(losses) / len(losses)
|
||||||
|
|
||||||
|
val_losses = []
val_loss = None  # stays None when no validation data is provided
|
||||||
|
if g_values_val is not None:
|
||||||
|
detector.eval()
|
||||||
|
if validation_metric == ValidationMetric.TPR_AT_FPR:
|
||||||
|
val_loss = update_fn_if_fpr_tpr(
|
||||||
|
detector,
|
||||||
|
g_values_val,
|
||||||
|
mask_val,
|
||||||
|
watermarked_val,
|
||||||
|
minibatch_size=minibatch_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for i in range(0, len(g_values_val), minibatch_size):
|
||||||
|
end = i + minibatch_size
|
||||||
|
if end > len(g_values_val):
|
||||||
|
break
|
||||||
|
with torch.no_grad():
|
||||||
|
v_loss = detector(
|
||||||
|
g_values=g_values_val[i:end],
|
||||||
|
mask=mask_val[i:end],
|
||||||
|
labels=watermarked_val[i:end],
|
||||||
|
loss_batch_weight=0,
|
||||||
|
)[1]
|
||||||
|
val_losses.append(v_loss.item())
|
||||||
|
val_loss = sum(val_losses) / len(val_losses)
|
||||||
|
|
||||||
|
# Store training history
|
||||||
|
history[epoch + 1] = {"loss": train_loss, "val_loss": val_loss}
|
||||||
|
if verbose:
|
||||||
|
if val_loss is not None:
    print(f"Epoch {epoch}: loss {train_loss} (train), {val_loss} (val)")
else:
    print(f"Epoch {epoch}: loss {train_loss} (train)")
|
||||||
|
|
||||||
|
if val_loss is not None and val_loss < min_val_loss:
|
||||||
|
min_val_loss = val_loss
|
||||||
|
best_val_epoch = epoch
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f"Best val Epoch: {best_val_epoch}, min_val_loss: {min_val_loss}")
|
||||||
|
|
||||||
|
return history, min_val_loss
|
||||||
|
|
||||||
|
|
||||||
|
def train_best_detector(
|
||||||
|
tokenized_wm_outputs: Union[List[np.ndarray], np.ndarray],
|
||||||
|
tokenized_uwm_outputs: Union[List[np.ndarray], np.ndarray],
|
||||||
|
logits_processor: SynthIDTextWatermarkLogitsProcessor,
|
||||||
|
tokenizer: Any,
|
||||||
|
torch_device: torch.device,
|
||||||
|
test_size: float = 0.3,
|
||||||
|
pos_truncation_length: Optional[int] = 200,
|
||||||
|
neg_truncation_length: Optional[int] = 100,
|
||||||
|
max_padded_length: int = 2300,
|
||||||
|
n_epochs: int = 50,
|
||||||
|
learning_rate: float = 2.1e-2,
|
||||||
|
l2_weights: np.ndarray = np.logspace(-3, -2, num=4),
|
||||||
|
verbose: bool = False,
|
||||||
|
validation_metric: ValidationMetric = ValidationMetric.TPR_AT_FPR,
|
||||||
|
):
|
||||||
|
"""Train and return the best detector given range of hyperparameters.
|
||||||
|
|
||||||
|
In practice, we have found that tuning pos_truncation_length,
|
||||||
|
neg_truncation_length, n_epochs, learning_rate and l2_weights can help
|
||||||
|
improve the performance of the detector. We recommend tuning these
|
||||||
|
parameters for your data.
|
||||||
|
"""
|
||||||
|
l2_weights = list(l2_weights)
|
||||||
|
|
||||||
|
(
|
||||||
|
train_g_values,
|
||||||
|
train_masks,
|
||||||
|
train_labels,
|
||||||
|
cv_g_values,
|
||||||
|
cv_masks,
|
||||||
|
cv_labels,
|
||||||
|
) = process_raw_model_outputs(
|
||||||
|
logits_processor,
|
||||||
|
tokenizer,
|
||||||
|
pos_truncation_length,
|
||||||
|
neg_truncation_length,
|
||||||
|
max_padded_length,
|
||||||
|
tokenized_wm_outputs,
|
||||||
|
test_size,
|
||||||
|
tokenized_uwm_outputs,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
best_detector = None
|
||||||
|
lowest_loss = float("inf")
|
||||||
|
val_losses = []
|
||||||
|
for l2_weight in l2_weights:
|
||||||
|
config = BayesianDetectorConfig(watermarking_depth=len(logits_processor.keys))
|
||||||
|
detector = BayesianDetectorModel(config).to(torch_device)
|
||||||
|
_, min_val_loss = train_detector(
|
||||||
|
detector=detector,
|
||||||
|
g_values=train_g_values,
|
||||||
|
mask=train_masks,
|
||||||
|
watermarked=train_labels,
|
||||||
|
g_values_val=cv_g_values,
|
||||||
|
mask_val=cv_masks,
|
||||||
|
watermarked_val=cv_labels,
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
l2_weight=l2_weight,
|
||||||
|
epochs=n_epochs,
|
||||||
|
verbose=verbose,
|
||||||
|
validation_metric=validation_metric,
|
||||||
|
)
|
||||||
|
val_losses.append(min_val_loss)
|
||||||
|
if min_val_loss < lowest_loss:
|
||||||
|
lowest_loss = min_val_loss
|
||||||
|
best_detector = detector
|
||||||
|
return best_detector, lowest_loss
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name",
|
||||||
|
type=str,
|
||||||
|
default="google/gemma-2b-it",
|
||||||
|
help=("LM model to train the detector for."),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help=("Temperature to sample from the model."),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top_k",
|
||||||
|
type=int,
|
||||||
|
default=40,
|
||||||
|
help=("Top K for sampling."),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top_p",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help=("Top P for sampling."),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_negatives",
|
||||||
|
type=int,
|
||||||
|
default=10000,
|
||||||
|
help=("Number of negatives for detector training."),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pos_batch_size",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help=("Batch size of watermarked positives while sampling."),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_pos_batch",
|
||||||
|
type=int,
|
||||||
|
default=313,
|
||||||
|
help=("Number of positive batches for training."),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--generation_length",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help=("Generation length for sampling."),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_model_to_hf_hub",
|
||||||
|
action="store_true",
|
||||||
|
help=("Whether to save the trained model HF hub. By default it will be a private repo."),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--load_from_hf_hub",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"Whether to load trained detector model from HF Hub, make sure its the model trained on the same model "
|
||||||
|
"we are loading in the script."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf_hub_model_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=("HF hub model name for loading of saving the model."),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--eval_detector_on_prompts",
|
||||||
|
action="store_true",
|
||||||
|
help=("Evaluate detector on a prompt and print probability of watermark."),
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
model_name = args.model_name
|
||||||
|
temperature = args.temperature
|
||||||
|
top_k = args.top_k
|
||||||
|
top_p = args.top_p
|
||||||
|
num_negatives = args.num_negatives
|
||||||
|
pos_batch_size = args.pos_batch_size
|
||||||
|
num_pos_batch = args.num_pos_batch
|
||||||
|
if num_pos_batch < 10:
|
||||||
|
raise ValueError("--num_pos_batch should be greater than 10.")
|
||||||
|
generation_length = args.generation_length
|
||||||
|
save_model_to_hf_hub = args.save_model_to_hf_hub
|
||||||
|
load_from_hf_hub = args.load_from_hf_hub
|
||||||
|
repo_name = args.hf_hub_model_name
|
||||||
|
eval_detector_on_prompts = args.eval_detector_on_prompts
|
||||||
|
|
||||||
|
NEG_BATCH_SIZE = 32
|
||||||
|
|
||||||
|
# Truncate outputs to this length for training.
|
||||||
|
POS_TRUNCATION_LENGTH = 200
|
||||||
|
NEG_TRUNCATION_LENGTH = 100
|
||||||
|
# Pad truncated outputs to this length for equal shape across all batches.
|
||||||
|
MAX_PADDED_LENGTH = 1000
|
||||||
|
|
||||||
|
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
if DEVICE.type not in ("cuda", "tpu"):
|
||||||
|
raise ValueError("We have found the training stable on GPU and TPU, we are working on" " a fix for CPUs")
|
||||||
|
|
||||||
|
model = None
|
||||||
|
if not load_from_hf_hub:
|
||||||
|
# Change this to make your watermark unique. Check documentation in the paper to understand the
|
||||||
|
# impact of these parameters.
|
||||||
|
DEFAULT_WATERMARKING_CONFIG = {
|
||||||
|
"ngram_len": 5, # This corresponds to H=4 context window size in the paper.
|
||||||
|
"keys": [
|
||||||
|
654,
|
||||||
|
400,
|
||||||
|
836,
|
||||||
|
123,
|
||||||
|
340,
|
||||||
|
443,
|
||||||
|
597,
|
||||||
|
160,
|
||||||
|
57,
|
||||||
|
29,
|
||||||
|
590,
|
||||||
|
639,
|
||||||
|
13,
|
||||||
|
715,
|
||||||
|
468,
|
||||||
|
990,
|
||||||
|
966,
|
||||||
|
226,
|
||||||
|
324,
|
||||||
|
585,
|
||||||
|
118,
|
||||||
|
504,
|
||||||
|
421,
|
||||||
|
521,
|
||||||
|
129,
|
||||||
|
669,
|
||||||
|
732,
|
||||||
|
225,
|
||||||
|
90,
|
||||||
|
960,
|
||||||
|
],
|
||||||
|
"sampling_table_size": 2**16,
|
||||||
|
"sampling_table_seed": 0,
|
||||||
|
"context_history_size": 1024,
|
||||||
|
}
|
||||||
|
watermark_config = SynthIDTextWatermarkingConfig(**DEFAULT_WATERMARKING_CONFIG)
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
logits_processor = SynthIDTextWatermarkLogitsProcessor(**DEFAULT_WATERMARKING_CONFIG, device=DEVICE)
|
||||||
|
tokenized_wm_outputs = get_tokenized_wm_outputs(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
watermark_config,
|
||||||
|
num_pos_batch,
|
||||||
|
pos_batch_size,
|
||||||
|
temperature,
|
||||||
|
generation_length,
|
||||||
|
top_k,
|
||||||
|
top_p,
|
||||||
|
DEVICE,
|
||||||
|
)
|
||||||
|
tokenized_uwm_outputs = get_tokenized_uwm_outputs(num_negatives, NEG_BATCH_SIZE, tokenizer, DEVICE)
|
||||||
|
|
||||||
|
best_detector, lowest_loss = train_best_detector(
|
||||||
|
tokenized_wm_outputs=tokenized_wm_outputs,
|
||||||
|
tokenized_uwm_outputs=tokenized_uwm_outputs,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
torch_device=DEVICE,
|
||||||
|
test_size=0.3,
|
||||||
|
pos_truncation_length=POS_TRUNCATION_LENGTH,
|
||||||
|
neg_truncation_length=NEG_TRUNCATION_LENGTH,
|
||||||
|
max_padded_length=MAX_PADDED_LENGTH,
|
||||||
|
n_epochs=100,
|
||||||
|
learning_rate=3e-3,
|
||||||
|
l2_weights=[
|
||||||
|
0,
|
||||||
|
],
|
||||||
|
verbose=True,
|
||||||
|
validation_metric=ValidationMetric.TPR_AT_FPR,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if repo_name is None:
|
||||||
|
raise ValueError("When loading from pretrained detector model name cannot be None.")
|
||||||
|
best_detector = BayesianDetectorModel.from_pretrained(repo_name).to(DEVICE)
|
||||||
|
|
||||||
|
best_detector.config.set_detector_information(
|
||||||
|
model_name=model_name, watermarking_config=DEFAULT_WATERMARKING_CONFIG
|
||||||
|
)
|
||||||
|
if save_model_to_hf_hub:
|
||||||
|
upload_model_to_hf(best_detector, repo_name)
|
||||||
|
|
||||||
|
# Evaluate model response with the detector
|
||||||
|
if eval_detector_on_prompts:
|
||||||
|
model_name = best_detector.config.model_name
|
||||||
|
watermark_config_dict = best_detector.config.watermarking_config
|
||||||
|
logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermark_config_dict, device=DEVICE)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
synthid_text_detector = SynthIDTextWatermarkDetector(best_detector, logits_processor, tokenizer)
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
|
||||||
|
watermarking_config = SynthIDTextWatermarkingConfig(**watermark_config_dict)
|
||||||
|
|
||||||
|
prompts = ["Write a essay on cats."]
|
||||||
|
inputs = tokenizer(
|
||||||
|
prompts,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
).to(DEVICE)
|
||||||
|
|
||||||
|
_, inputs_len = inputs["input_ids"].shape
|
||||||
|
|
||||||
|
outputs = model.generate(
|
||||||
|
**inputs,
|
||||||
|
watermarking_config=watermarking_config,
|
||||||
|
do_sample=True,
|
||||||
|
max_length=inputs_len + generation_length,
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=40,
|
||||||
|
top_p=1.0,
|
||||||
|
)
|
||||||
|
outputs = outputs[:, inputs_len:]
|
||||||
|
result = synthid_text_detector(outputs)
|
||||||
|
|
||||||
|
# You should set this based on expected fpr (false positive rate) and tpr (true positive rate).
|
||||||
|
# Check our demo at HF Spaces for more info.
|
||||||
|
upper_threshold = 0.95
|
||||||
|
lower_threshold = 0.12
|
||||||
|
if result[0][0] > upper_threshold:
|
||||||
|
print("The text is watermarked.")
|
||||||
|
elif lower_threshold < result[0][0] < upper_threshold:
|
||||||
|
print("It is hard to determine if the text is watermarked or not.")
|
||||||
|
else:
|
||||||
|
print("The text is not watermarked.")
|
examples/research_projects/synthid_text/requirements.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
tensorflow-datasets>=4.9.3
torch >= 1.3
datasets
scikit-learn
tensorflow
examples/research_projects/synthid_text/utils.py (new file, 408 lines)
@@ -0,0 +1,408 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 Google DeepMind.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_datasets as tfds
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
from huggingface_hub import HfApi, create_repo
|
||||||
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
|
from sklearn import model_selection
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
|
||||||
|
def pad_to_len(
|
||||||
|
arr: torch.Tensor,
|
||||||
|
target_len: int,
|
||||||
|
left_pad: bool,
|
||||||
|
eos_token: int,
|
||||||
|
device: torch.device,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Pad or truncate array to given length."""
|
||||||
|
if arr.shape[1] < target_len:
|
||||||
|
shape_for_ones = list(arr.shape)
|
||||||
|
shape_for_ones[1] = target_len - shape_for_ones[1]
|
||||||
|
padded = (
|
||||||
|
torch.ones(
|
||||||
|
shape_for_ones,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.long,
|
||||||
|
)
|
||||||
|
* eos_token
|
||||||
|
)
|
||||||
|
if not left_pad:
|
||||||
|
arr = torch.concatenate((arr, padded), dim=1)
|
||||||
|
else:
|
||||||
|
arr = torch.concatenate((padded, arr), dim=1)
|
||||||
|
else:
|
||||||
|
arr = arr[:, :target_len]
|
||||||
|
return arr
|
||||||
|
|
||||||
|
|
||||||
|
def filter_and_truncate(
|
||||||
|
outputs: torch.Tensor,
|
||||||
|
truncation_length: Optional[int],
|
||||||
|
eos_token_mask: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Filter and truncate outputs to given length.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs: output tensor of shape [batch_size, output_len]
|
||||||
|
truncation_length: Length to truncate the final output.
|
||||||
|
eos_token_mask: EOS token mask of shape [batch_size, output_len]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output tensor of shape [batch_size, truncation_length].
|
||||||
|
"""
|
||||||
|
if truncation_length:
|
||||||
|
outputs = outputs[:, :truncation_length]
|
||||||
|
truncation_mask = torch.sum(eos_token_mask, dim=1) >= truncation_length
|
||||||
|
return outputs[truncation_mask, :]
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def process_outputs_for_training(
|
||||||
|
all_outputs: List[torch.Tensor],
|
||||||
|
logits_processor: transformers.generation.SynthIDTextWatermarkLogitsProcessor,
|
||||||
|
tokenizer: Any,
|
||||||
|
pos_truncation_length: Optional[int],
|
||||||
|
neg_truncation_length: Optional[int],
|
||||||
|
max_length: int,
|
||||||
|
is_cv: bool,
|
||||||
|
is_pos: bool,
|
||||||
|
torch_device: torch.device,
|
||||||
|
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||||
|
"""Process raw model outputs into format understandable by the detector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
all_outputs: sequence of outputs of shape [batch_size, output_len].
|
||||||
|
logits_processor: logits processor used for watermarking.
|
||||||
|
tokenizer: tokenizer used for the model.
|
||||||
|
pos_truncation_length: Length to truncate wm outputs.
|
||||||
|
neg_truncation_length: Length to truncate uwm outputs.
|
||||||
|
max_length: Length to pad truncated outputs so that all processed entries
have the same shape.
|
||||||
|
is_cv: Process given outputs for cross validation.
|
||||||
|
is_pos: Process given outputs for positives.
|
||||||
|
torch_device: torch device to use.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of
|
||||||
|
all_masks: list of masks of shape [batch_size, max_length].
|
||||||
|
all_g_values: list of g_values of shape [batch_size, max_length, depth].
|
||||||
|
"""
|
||||||
|
all_masks = []
|
||||||
|
all_g_values = []
|
||||||
|
for outputs in tqdm.tqdm(all_outputs):
|
||||||
|
# outputs is of shape [batch_size, output_len].
|
||||||
|
# output_len can differ from batch to batch.
|
||||||
|
eos_token_mask = logits_processor.compute_eos_token_mask(
|
||||||
|
input_ids=outputs,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
)
|
||||||
|
if is_pos or is_cv:
|
||||||
|
# filter with length for positives for both train and CV.
|
||||||
|
# We also filter for length when CV negatives are processed.
|
||||||
|
outputs = filter_and_truncate(outputs, pos_truncation_length, eos_token_mask)
|
||||||
|
elif not is_pos and not is_cv:
|
||||||
|
outputs = filter_and_truncate(outputs, neg_truncation_length, eos_token_mask)
|
||||||
|
|
||||||
|
# If no filtered outputs skip this batch.
|
||||||
|
if outputs.shape[0] == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# All outputs are padded to max-length with eos-tokens.
|
||||||
|
outputs = pad_to_len(outputs, max_length, False, tokenizer.eos_token_id, torch_device)
|
||||||
|
# outputs shape [num_filtered_entries, max_length]
|
||||||
|
|
||||||
|
eos_token_mask = logits_processor.compute_eos_token_mask(
|
||||||
|
input_ids=outputs,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
context_repetition_mask = logits_processor.compute_context_repetition_mask(
|
||||||
|
input_ids=outputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# context_repetition_mask of shape [num_filtered_entries, max_length -
|
||||||
|
# (ngram_len - 1)].
|
||||||
|
context_repetition_mask = pad_to_len(context_repetition_mask, max_length, True, 0, torch_device)
|
||||||
|
# We pad on left to get same max_length shape.
|
||||||
|
# context_repetition_mask of shape [num_filtered_entries, max_length].
|
||||||
|
combined_mask = context_repetition_mask * eos_token_mask
|
||||||
|
|
||||||
|
g_values = logits_processor.compute_g_values(
|
||||||
|
input_ids=outputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# g_values of shape [num_filtered_entries, max_length - (ngram_len - 1),
|
||||||
|
# depth].
|
||||||
|
g_values = pad_to_len(g_values, max_length, True, 0, torch_device)
|
||||||
|
|
||||||
|
# We pad on left to get same max_length shape.
|
||||||
|
# g_values of shape [num_filtered_entries, max_length, depth].
|
||||||
|
all_masks.append(combined_mask)
|
||||||
|
all_g_values.append(g_values)
|
||||||
|
return all_masks, all_g_values
|
||||||
|
|
||||||
|
|
||||||
|
def tpr_at_fpr(detector, detector_inputs, w_true, minibatch_size, target_fpr=0.01) -> torch.Tensor:
|
||||||
|
"""Calculates true positive rate (TPR) at false positive rate (FPR)=target_fpr."""
|
||||||
|
positive_idxs = w_true == 1
|
||||||
|
negative_idxs = w_true == 0
|
||||||
|
num_samples = detector_inputs[0].size(0)
|
||||||
|
|
||||||
|
w_preds = []
|
||||||
|
for start in range(0, num_samples, minibatch_size):
|
||||||
|
end = start + minibatch_size
|
||||||
|
detector_inputs_ = (
|
||||||
|
detector_inputs[0][start:end],
|
||||||
|
detector_inputs[1][start:end],
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
w_pred = detector(*detector_inputs_)[0]
|
||||||
|
w_preds.append(w_pred)
|
||||||
|
|
||||||
|
w_pred = torch.cat(w_preds, dim=0) # Concatenate predictions
|
||||||
|
positive_scores = w_pred[positive_idxs]
|
||||||
|
negative_scores = w_pred[negative_idxs]
|
||||||
|
|
||||||
|
# Calculate the FPR threshold
|
||||||
|
# Note: percentile -> quantile
|
||||||
|
fpr_threshold = torch.quantile(negative_scores, 1 - target_fpr)
|
||||||
|
# Note: need to switch to FP32 since torch.mean doesn't work with torch.bool
|
||||||
|
return torch.mean((positive_scores >= fpr_threshold).to(dtype=torch.float32)).item() # TPR
|
||||||
|
|
||||||
|
|
||||||
|
def update_fn_if_fpr_tpr(detector, g_values_val, mask_val, watermarked_val, minibatch_size):
|
||||||
|
"""Loss function for negative TPR@FPR=1% as the validation loss."""
|
||||||
|
tpr_ = tpr_at_fpr(
|
||||||
|
detector=detector,
|
||||||
|
detector_inputs=(g_values_val, mask_val),
|
||||||
|
w_true=watermarked_val,
|
||||||
|
minibatch_size=minibatch_size,
|
||||||
|
)
|
||||||
|
return -tpr_
|
||||||
|
|
||||||
|
|
||||||
|
def process_raw_model_outputs(
|
||||||
|
logits_processor,
|
||||||
|
tokenizer,
|
||||||
|
pos_truncation_length,
|
||||||
|
neg_truncation_length,
|
||||||
|
max_padded_length,
|
||||||
|
tokenized_wm_outputs,
|
||||||
|
test_size,
|
||||||
|
tokenized_uwm_outputs,
|
||||||
|
torch_device,
|
||||||
|
):
|
||||||
|
# Split data into train and CV
|
||||||
|
train_wm_outputs, cv_wm_outputs = model_selection.train_test_split(tokenized_wm_outputs, test_size=test_size)
|
||||||
|
|
||||||
|
train_uwm_outputs, cv_uwm_outputs = model_selection.train_test_split(tokenized_uwm_outputs, test_size=test_size)
|
||||||
|
|
||||||
|
process_kwargs = {
|
||||||
|
"logits_processor": logits_processor,
|
||||||
|
"tokenizer": tokenizer,
|
||||||
|
"pos_truncation_length": pos_truncation_length,
|
||||||
|
"neg_truncation_length": neg_truncation_length,
|
||||||
|
"max_length": max_padded_length,
|
||||||
|
"torch_device": torch_device,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process both train and CV data for training
|
||||||
|
wm_masks_train, wm_g_values_train = process_outputs_for_training(
|
||||||
|
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in train_wm_outputs],
|
||||||
|
is_pos=True,
|
||||||
|
is_cv=False,
|
||||||
|
**process_kwargs,
|
||||||
|
)
|
||||||
|
wm_masks_cv, wm_g_values_cv = process_outputs_for_training(
|
||||||
|
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in cv_wm_outputs],
|
||||||
|
is_pos=True,
|
||||||
|
is_cv=True,
|
||||||
|
**process_kwargs,
|
||||||
|
)
|
||||||
|
uwm_masks_train, uwm_g_values_train = process_outputs_for_training(
|
||||||
|
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in train_uwm_outputs],
|
||||||
|
is_pos=False,
|
||||||
|
is_cv=False,
|
||||||
|
**process_kwargs,
|
||||||
|
)
|
||||||
|
uwm_masks_cv, uwm_g_values_cv = process_outputs_for_training(
|
||||||
|
[torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in cv_uwm_outputs],
|
||||||
|
is_pos=False,
|
||||||
|
is_cv=True,
|
||||||
|
**process_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We get list of data; here we concat all together to be passed to the detector.
|
||||||
|
def pack(mask, g_values):
|
||||||
|
mask = torch.cat(mask, dim=0)
|
||||||
|
g = torch.cat(g_values, dim=0)
|
||||||
|
return mask, g
|
||||||
|
|
||||||
|
wm_masks_train, wm_g_values_train = pack(wm_masks_train, wm_g_values_train)
|
||||||
|
# Note: Use float instead of bool. Otherwise, the entropy calculation doesn't work
|
||||||
|
wm_labels_train = torch.ones((wm_masks_train.shape[0],), dtype=torch.float, device=torch_device)
|
||||||
|
|
||||||
|
wm_masks_cv, wm_g_values_cv = pack(wm_masks_cv, wm_g_values_cv)
|
||||||
|
wm_labels_cv = torch.ones((wm_masks_cv.shape[0],), dtype=torch.float, device=torch_device)
|
||||||
|
|
||||||
|
uwm_masks_train, uwm_g_values_train = pack(uwm_masks_train, uwm_g_values_train)
|
||||||
|
uwm_labels_train = torch.zeros((uwm_masks_train.shape[0],), dtype=torch.float, device=torch_device)
|
||||||
|
|
||||||
|
uwm_masks_cv, uwm_g_values_cv = pack(uwm_masks_cv, uwm_g_values_cv)
|
||||||
|
uwm_labels_cv = torch.zeros((uwm_masks_cv.shape[0],), dtype=torch.float, device=torch_device)
|
||||||
|
|
||||||
|
# Concat pos and negatives data together.
|
||||||
|
train_g_values = torch.cat((wm_g_values_train, uwm_g_values_train), dim=0).squeeze()
|
||||||
|
train_labels = torch.cat((wm_labels_train, uwm_labels_train), axis=0).squeeze()
|
||||||
|
train_masks = torch.cat((wm_masks_train, uwm_masks_train), axis=0).squeeze()
|
||||||
|
|
||||||
|
cv_g_values = torch.cat((wm_g_values_cv, uwm_g_values_cv), axis=0).squeeze()
|
||||||
|
cv_labels = torch.cat((wm_labels_cv, uwm_labels_cv), axis=0).squeeze()
|
||||||
|
cv_masks = torch.cat((wm_masks_cv, uwm_masks_cv), axis=0).squeeze()
|
||||||
|
|
||||||
|
# Shuffle data.
|
||||||
|
shuffled_idx = torch.randperm(train_g_values.shape[0]) # Use torch for GPU compatibility
|
||||||
|
|
||||||
|
train_g_values = train_g_values[shuffled_idx]
|
||||||
|
train_labels = train_labels[shuffled_idx]
|
||||||
|
train_masks = train_masks[shuffled_idx]
|
||||||
|
|
||||||
|
# Shuffle the cross-validation data
|
||||||
|
shuffled_idx_cv = torch.randperm(cv_g_values.shape[0]) # Use torch for GPU compatibility
|
||||||
|
cv_g_values = cv_g_values[shuffled_idx_cv]
|
||||||
|
cv_labels = cv_labels[shuffled_idx_cv]
|
||||||
|
cv_masks = cv_masks[shuffled_idx_cv]
|
||||||
|
|
||||||
|
# Del some variables so we free up GPU memory.
|
||||||
|
del (
|
||||||
|
wm_g_values_train,
|
||||||
|
wm_labels_train,
|
||||||
|
wm_masks_train,
|
||||||
|
wm_g_values_cv,
|
||||||
|
wm_labels_cv,
|
||||||
|
wm_masks_cv,
|
||||||
|
)
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return train_g_values, train_masks, train_labels, cv_g_values, cv_masks, cv_labels
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenized_uwm_outputs(num_negatives, neg_batch_size, tokenizer, device):
|
||||||
|
dataset, info = tfds.load("wikipedia/20230601.en", split="train", with_info=True)
|
||||||
|
dataset = dataset.take(num_negatives)
|
||||||
|
|
||||||
|
# Convert the dataset to a DataFrame
|
||||||
|
df = tfds.as_dataframe(dataset, info)
|
||||||
|
ds = tf.data.Dataset.from_tensor_slices(dict(df))
|
||||||
|
tf.random.set_seed(0)
|
||||||
|
ds = ds.shuffle(buffer_size=10_000)
|
||||||
|
ds = ds.batch(batch_size=neg_batch_size)
|
||||||
|
|
||||||
|
tokenized_uwm_outputs = []
|
||||||
|
# Pad to this length (on the right) for batching.
|
||||||
|
padded_length = 1000
|
||||||
|
for i, batch in tqdm.tqdm(enumerate(ds)):
|
||||||
|
responses = [val.decode() for val in batch["text"].numpy()]
|
||||||
|
inputs = tokenizer(
|
||||||
|
responses,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
).to(device)
|
||||||
|
inputs = inputs["input_ids"].cpu().numpy()
|
||||||
|
if inputs.shape[1] >= padded_length:
|
||||||
|
inputs = inputs[:, :padded_length]
|
||||||
|
else:
|
||||||
|
inputs = np.concatenate(
|
||||||
|
[inputs, np.ones((neg_batch_size, padded_length - inputs.shape[1])) * tokenizer.eos_token_id], axis=1
|
||||||
|
)
|
||||||
|
tokenized_uwm_outputs.append(inputs)
|
||||||
|
if len(tokenized_uwm_outputs) * neg_batch_size > num_negatives:
|
||||||
|
break
|
||||||
|
return tokenized_uwm_outputs
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenized_wm_outputs(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
watermark_config,
|
||||||
|
num_pos_batches,
|
||||||
|
pos_batch_size,
|
||||||
|
temperature,
|
||||||
|
max_output_len,
|
||||||
|
top_k,
|
||||||
|
top_p,
|
||||||
|
device,
|
||||||
|
):
|
||||||
|
eli5_prompts = datasets.load_dataset("Pavithree/eli5")
|
||||||
|
|
||||||
|
wm_outputs = []
|
||||||
|
|
||||||
|
for batch_id in tqdm.tqdm(range(num_pos_batches)):
|
||||||
|
prompts = eli5_prompts["train"]["title"][batch_id * pos_batch_size : (batch_id + 1) * pos_batch_size]
|
||||||
|
prompts = [prompt.strip('"') for prompt in prompts]
|
||||||
|
inputs = tokenizer(
|
||||||
|
prompts,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
).to(device)
|
||||||
|
_, inputs_len = inputs["input_ids"].shape
|
||||||
|
|
||||||
|
outputs = model.generate(
|
||||||
|
**inputs,
|
||||||
|
watermarking_config=watermark_config,
|
||||||
|
do_sample=True,
|
||||||
|
max_length=inputs_len + max_output_len,
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
wm_outputs.append(outputs[:, inputs_len:].cpu().detach())
|
||||||
|
|
||||||
|
del outputs, inputs, prompts
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return wm_outputs
|
||||||
|
|
||||||
|
|
||||||
|
def upload_model_to_hf(model, hf_repo_name: str, private: bool = True):
|
||||||
|
api = HfApi()
|
||||||
|
|
||||||
|
# Check if the repository exists
|
||||||
|
try:
|
||||||
|
api.repo_info(repo_id=hf_repo_name, use_auth_token=True)
|
||||||
|
print(f"Repository '{hf_repo_name}' already exists.")
|
||||||
|
except RepositoryNotFoundError:
|
||||||
|
# If the repository does not exist, create it
|
||||||
|
print(f"Repository '{hf_repo_name}' not found. Creating it...")
|
||||||
|
create_repo(repo_id=hf_repo_name, private=private, use_auth_token=True)
|
||||||
|
print(f"Repository '{hf_repo_name}' created successfully.")
|
||||||
|
|
||||||
|
# Push the model to the Hugging Face Hub
|
||||||
|
print(f"Uploading model to Hugging Face repo '{hf_repo_name}'...")
|
||||||
|
model.push_to_hub(repo_id=hf_repo_name, use_auth_token=True)
|
src/transformers/__init__.py
@@ -1301,6 +1301,8 @@ else:
     _import_structure["generation"].extend(
         [
             "AlternatingCodebooksLogitsProcessor",
+            "BayesianDetectorConfig",
+            "BayesianDetectorModel",
             "BeamScorer",
             "BeamSearchScorer",
             "ClassifierFreeGuidanceLogitsProcessor",
@@ -1339,6 +1341,9 @@ else:
             "StopStringCriteria",
             "SuppressTokensAtBeginLogitsProcessor",
             "SuppressTokensLogitsProcessor",
+            "SynthIDTextWatermarkDetector",
+            "SynthIDTextWatermarkingConfig",
+            "SynthIDTextWatermarkLogitsProcessor",
             "TemperatureLogitsWarper",
             "TopKLogitsWarper",
             "TopPLogitsWarper",
@@ -6213,6 +6218,8 @@ if TYPE_CHECKING:
     )
     from .generation import (
         AlternatingCodebooksLogitsProcessor,
+        BayesianDetectorConfig,
+        BayesianDetectorModel,
         BeamScorer,
         BeamSearchScorer,
         ClassifierFreeGuidanceLogitsProcessor,
@@ -6251,6 +6258,9 @@ if TYPE_CHECKING:
         StopStringCriteria,
         SuppressTokensAtBeginLogitsProcessor,
         SuppressTokensLogitsProcessor,
+        SynthIDTextWatermarkDetector,
+        SynthIDTextWatermarkingConfig,
+        SynthIDTextWatermarkLogitsProcessor,
         TemperatureLogitsWarper,
         TopKLogitsWarper,
         TopPLogitsWarper,
src/transformers/generation/__init__.py
@@ -18,7 +18,13 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab
 
 
 _import_structure = {
-    "configuration_utils": ["GenerationConfig", "GenerationMode", "WatermarkingConfig"],
+    "configuration_utils": [
+        "BaseWatermarkingConfig",
+        "GenerationConfig",
+        "GenerationMode",
+        "SynthIDTextWatermarkingConfig",
+        "WatermarkingConfig",
+    ],
     "streamers": ["TextIteratorStreamer", "TextStreamer"],
 }
 
@@ -71,6 +77,7 @@ else:
         "SequenceBiasLogitsProcessor",
         "SuppressTokensLogitsProcessor",
         "SuppressTokensAtBeginLogitsProcessor",
+        "SynthIDTextWatermarkLogitsProcessor",
         "TemperatureLogitsWarper",
         "TopKLogitsWarper",
         "TopPLogitsWarper",
@@ -110,6 +117,9 @@ else:
     _import_structure["watermarking"] = [
         "WatermarkDetector",
         "WatermarkDetectorOutput",
+        "BayesianDetectorModel",
+        "BayesianDetectorConfig",
+        "SynthIDTextWatermarkDetector",
     ]
 
 try:
@@ -179,7 +189,13 @@ else:
     ]
 
 if TYPE_CHECKING:
-    from .configuration_utils import GenerationConfig, GenerationMode, WatermarkingConfig
+    from .configuration_utils import (
+        BaseWatermarkingConfig,
+        GenerationConfig,
+        GenerationMode,
+        SynthIDTextWatermarkingConfig,
+        WatermarkingConfig,
+    )
     from .streamers import TextIteratorStreamer, TextStreamer
 
 try:
@@ -217,6 +233,7 @@ if TYPE_CHECKING:
         SequenceBiasLogitsProcessor,
         SuppressTokensAtBeginLogitsProcessor,
         SuppressTokensLogitsProcessor,
+        SynthIDTextWatermarkLogitsProcessor,
         TemperatureLogitsWarper,
         TopKLogitsWarper,
         TopPLogitsWarper,
@@ -254,6 +271,9 @@ if TYPE_CHECKING:
         SampleEncoderDecoderOutput,
     )
     from .watermarking import (
+        BayesianDetectorConfig,
+        BayesianDetectorModel,
+        SynthIDTextWatermarkDetector,
         WatermarkDetector,
         WatermarkDetectorOutput,
     )
src/transformers/generation/configuration_utils.py
@@ -18,8 +18,9 @@ import copy
 import json
 import os
 import warnings
+from abc import ABC, abstractmethod
 from dataclasses import dataclass, is_dataclass
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from .. import __version__
 from ..configuration_utils import PretrainedConfig
@@ -59,6 +60,7 @@ if is_torch_available():
         StaticCache,
         StaticCacheConfig,
     )
+    from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
 
 NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
 NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig
@@ -280,23 +282,10 @@ class GenerationConfig(PushToHubMixin):
         low_memory (`bool`, *optional*):
             Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
             Used with beam search and contrastive search.
-        watermarking_config (`WatermarkingConfig` or `dict`, *optional*):
-            Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green" tokens.
-            If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
-            See [this paper](https://arxiv.org/abs/2306.04634) for more details. Accepts the following keys:
-            - greenlist_ratio (`float`):
-                Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
-            - bias (`float`):
-                Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
-            - hashing_key (`int`):
-                Hahsing key used for watermarking. Defaults to 15485863 (the millionth prime).
-            - seeding_scheme (`str`):
-                Algorithm to use for watermarking. Accepts values:
-                    - "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
-                    - "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
-                        The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
-            - context_width (`int`):
-                The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
+        watermarking_config (`BaseWatermarkingConfig` or `dict`, *optional*):
+            Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green"
+            tokens. See the docs of [`SynthIDTextWatermarkingConfig`] and [`WatermarkingConfig`] for more
+            details. If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
 
         > Parameters that define the output variables of generate
 
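Not part of the diff: a minimal usage sketch of the `watermarking_config` argument described in the docstring above, mirroring the watermarked-generation calls in `detector_training.py`. The model id is only a placeholder.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, WatermarkingConfig

model_id = "google/gemma-2b-it"  # placeholder; any causal LM works the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Green-list watermarking (https://arxiv.org/abs/2306.04634); a SynthIDTextWatermarkingConfig
# instance can be passed through the same argument.
watermarking_config = WatermarkingConfig(greenlist_ratio=0.25, bias=2.0, seeding_scheme="lefthash")

inputs = tokenizer(["Write an essay on cats."], return_tensors="pt")
outputs = model.generate(
    **inputs,
    watermarking_config=watermarking_config,
    do_sample=True,
    max_new_tokens=64,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```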
@@ -430,7 +419,7 @@ class GenerationConfig(PushToHubMixin):
         watermarking_config = kwargs.pop("watermarking_config", None)
         if watermarking_config is None:
             self.watermarking_config = None
-        elif isinstance(watermarking_config, WatermarkingConfig):
+        elif isinstance(watermarking_config, BaseWatermarkingConfig):
             self.watermarking_config = watermarking_config
         else:
             self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config)
@@ -766,7 +755,15 @@ class GenerationConfig(PushToHubMixin):
 
         # 6. check watermarking arguments
         if self.watermarking_config is not None:
-            if not isinstance(self.watermarking_config, WatermarkingConfig):
+            if not (
+                isinstance(self.watermarking_config, WatermarkingConfig)
+                or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
+            ):
+                warnings.warn(
+                    "`watermarking_config` as a dict is deprecated. Please construct `watermarking_config` object with "
+                    "`WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class.",
+                    FutureWarning,
+                )
                 self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
             self.watermarking_config.validate()
 
@ -1287,52 +1284,20 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WatermarkingConfig:
|
class BaseWatermarkingConfig(ABC):
|
||||||
"""
|
"""Generic watermarking config"""
|
||||||
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
|
|
||||||
See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.
|
|
||||||
|
|
||||||
Accepts the following keys:
|
|
||||||
- greenlist_ratio (`float`):
|
|
||||||
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
|
|
||||||
- bias (`float`):
|
|
||||||
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
|
|
||||||
- hashing_key (`int`):
|
|
||||||
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
|
|
||||||
- seeding_scheme (`str`):
|
|
||||||
Algorithm to use for watermarking. Accepts values:
|
|
||||||
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
|
|
||||||
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
|
|
||||||
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
|
|
||||||
- context_width(`int`):
|
|
||||||
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
|
|
||||||
"""
|
|
||||||
|
|
||||||
-    def __init__(
-        self,
-        greenlist_ratio: Optional[float] = 0.25,
-        bias: Optional[float] = 2.0,
-        hashing_key: Optional[int] = 15485863,
-        seeding_scheme: Optional[str] = "lefthash",
-        context_width: Optional[int] = 1,
-    ):
-        self.greenlist_ratio = greenlist_ratio
-        self.bias = bias
-        self.hashing_key = hashing_key
-        self.seeding_scheme = seeding_scheme
-        self.context_width = context_width

     @classmethod
     def from_dict(cls, config_dict, **kwargs):
         """
-        Constructs a WatermarkingConfig instance from a dictionary of parameters.
+        Constructs a BaseWatermarkingConfig instance from a dictionary of parameters.

         Args:
             config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
             **kwargs: Additional keyword arguments to override dictionary values.

         Returns:
-            WatermarkingConfig: Instance of WatermarkingConfig constructed from the dictionary.
+            BaseWatermarkingConfig: Instance of BaseWatermarkingConfig constructed from the dictionary.
         """
         config = cls(**config_dict)
         to_remove = []

@@ -1394,6 +1359,49 @@ class WatermarkingConfig:
         if hasattr(self, key):
             setattr(self, key, value)

+    @abstractmethod
+    def validate(self): ...
+
+    @abstractmethod
+    def construct_processor(self, vocab_size): ...
+
+
+@dataclass
+class WatermarkingConfig(BaseWatermarkingConfig):
+    """
+    Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
+    See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.
+
+    Accepts the following keys:
+        - greenlist_ratio (`float`):
+            Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
+        - bias (`float`):
+            Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
+        - hashing_key (`int`):
+            Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
+        - seeding_scheme (`str`):
+            Algorithm to use for watermarking. Accepts values:
+                - "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
+                - "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
+                  The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
+        - context_width(`int`):
+            The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
+    """
+
+    def __init__(
+        self,
+        greenlist_ratio: Optional[float] = 0.25,
+        bias: Optional[float] = 2.0,
+        hashing_key: Optional[int] = 15485863,
+        seeding_scheme: Optional[str] = "lefthash",
+        context_width: Optional[int] = 1,
+    ):
+        self.greenlist_ratio = greenlist_ratio
+        self.bias = bias
+        self.hashing_key = hashing_key
+        self.seeding_scheme = seeding_scheme
+        self.context_width = context_width

     def validate(self):
         watermark_missing_arg_msg = (
             "Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "

@@ -1423,3 +1431,104 @@ class WatermarkingConfig:
                     found_value=self.context_width,
                 ),
             )

+    def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
+        return WatermarkLogitsProcessor(
+            vocab_size=vocab_size,
+            device=device,
+            greenlist_ratio=self.greenlist_ratio,
+            bias=self.bias,
+            hashing_key=self.hashing_key,
+            seeding_scheme=self.seeding_scheme,
+            context_width=self.context_width,
+        )
+
+
+@dataclass
+class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig):
+    """
+    Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
+    See [this paper](https://www.nature.com/articles/s41586-024-08025-4) for more details on the arguments.
+
+    Args:
+        ngram_len (`int`):
+            Ngram length.
+        keys (`List[int]`):
+            A sequence of watermarking keys, one for each depth.
+        context_history_size (`int`, *optional*, defaults to 1024):
+            Size of the tensor to keep track of seen contexts.
+        sampling_table_seed (`int`, *optional*, defaults to 0):
+            Random seed to generate the sampling table.
+        sampling_table_size (`int`, *optional*, defaults to 65536):
+            Size of the sampling table.
+        skip_first_ngram_calls (`bool`, *optional*, defaults to `False`):
+            Whether to skip first ngram calls.
+        debug_mode (`bool`, optional, *optional*, defaults to `False`):
+            Logits are modified to uniform one got before watermarking modification is applied. This is to test the
+            implementation.
+
+    Examples:
+    ```python
+    >>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig
+
+    >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
+    >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b-it')
+
+    >>> # SynthID Text configuration
+    >>> watermarking_config = SynthIDTextWatermarkingConfig(
+    ...     keys=[654, 400, 836, 123, 340, 443, 597, 160, 57],
+    ...     ngram_len=5,
+    ... )
+
+    >>> # Generation with watermarking
+    >>> tokenized_prompts = tokenizer(["your prompts here"])
+    >>> output_sequences = model.generate(
+    ...     **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True,
+    ... )
+    >>> watermarked_text = tokenizer.batch_decode(output_sequences)
+    ```
+    """
+
+    def __init__(
+        self,
+        ngram_len: int,
+        keys: List[int],
+        context_history_size: int = 1024,
+        sampling_table_seed: int = 0,
+        sampling_table_size: int = 2**16,
+        skip_first_ngram_calls: bool = False,
+        debug_mode: bool = False,
+    ):
+        self.ngram_len = ngram_len
+        self.keys = keys
+        self.sampling_table_size = sampling_table_size
+        self.sampling_table_seed = sampling_table_seed
+        self.context_history_size = context_history_size
+        self.skip_first_ngram_calls = skip_first_ngram_calls
+        self.debug_mode = debug_mode
+
+    def validate(self):
+        watermark_missing_arg_msg = (
+            "Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
+            "but found {found_value}"
+        )
+        if self.sampling_table_size > 2**24:
+            raise ValueError(
+                watermark_missing_arg_msg.format(
+                    key="sampling_table_size",
+                    correct_value="< 2**24",
+                    found_value=self.sampling_table_size,
+                ),
+            )
+
+    def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
+        return SynthIDTextWatermarkLogitsProcessor(
+            ngram_len=self.ngram_len,
+            keys=self.keys,
+            sampling_table_size=self.sampling_table_size,
+            sampling_table_seed=self.sampling_table_seed,
+            context_history_size=self.context_history_size,
+            device=device,
+            skip_first_ngram_calls=self.skip_first_ngram_calls,
+            debug_mode=self.debug_mode,
+        )
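With this refactor, `WatermarkingConfig` and the new `SynthIDTextWatermarkingConfig` become sibling subclasses of `BaseWatermarkingConfig`, and each builds its own logits processor through `construct_processor`. A minimal sketch of that shared interface (not part of the patch; the vocabulary size and device below are placeholder values):

```python
from transformers import SynthIDTextWatermarkingConfig, WatermarkingConfig

# Green/red-list watermarking and SynthID Text watermarking now share one base interface.
green_red_config = WatermarkingConfig(bias=2.0, seeding_scheme="lefthash")
synthid_config = SynthIDTextWatermarkingConfig(keys=[654, 400, 836, 123, 340], ngram_len=5)

for config in (green_red_config, synthid_config):
    config.validate()  # each subclass implements the abstract validate()
    processor = config.construct_processor(vocab_size=32000, device="cpu")
    print(type(processor).__name__)
# -> WatermarkLogitsProcessor, then SynthIDTextWatermarkLogitsProcessor
```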
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2020 The HuggingFace Inc. team
+# Copyright 2024 The HuggingFace Inc. team and Google DeepMind.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -2460,6 +2460,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
                 final_greenlist.append(greedy_predictions[i])
         return torch.tensor(final_greenlist, device=input_seq.device)

+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if input_ids.shape[-1] < self.context_width:
             logger.warning(
@@ -2477,3 +2478,478 @@ class WatermarkLogitsProcessor(LogitsProcessor):
             scores_processed[b_idx, greenlist_ids] = scores_processed[b_idx, greenlist_ids] + self.bias

         return scores_processed
+
+
+class SynthIDTextWatermarkState:
+    """SynthID watermarking state."""
+
+    def __init__(
+        self,
+        batch_size: int,
+        ngram_len: int,
+        context_history_size: int,
+        device: torch.device,
+    ):
+        """Initializes the state.
+
+        Args:
+            batch_size (`int`): Batch size.
+            ngram_len (`int`): Ngram length.
+            context_history_size (`int`): Size of the tensor to keep track of seen contexts.
+            device (`int`): Device to use.
+        """
+        self.context = torch.zeros(
+            (batch_size, ngram_len - 1),
+            dtype=torch.int64,
+            device=device,
+        )
+        self.context_history = torch.zeros(
+            (batch_size, context_history_size),
+            dtype=torch.int64,
+            device=device,
+        )
+        self.num_calls = 0
+
+
+class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor):
+    r"""
+    Logits processor that implements watermarking techniques for text generation models.
+    This class facilitates the application of SynthID text watermarking, a method for embedding imperceptible signals
+    into generated text to aid in detecting synthetic content. It operates by subtly manipulating the probabilities of
+    token selection during text generation in a manner that can be reliably recovered later for verification.
+
+    Key Features:
+    * **State Management:** Maintains internal state to track token sequences and generate watermarking keys
+    dynamically.
+
+    * **Key Generation:** Computes hashes based on token sequences and watermarking parameters to create unique keys
+    for each position.
+
+    * **G-Value Sampling:** Employs a pre-computed sampling table to sample watermarking values (g-values) based on
+    the generated keys.
+
+    * **Score Adjustment:** Applies calculated g-values to modify token probabilities during generation, embedding the
+    watermark.
+
+    * **Context Repetition Handling:** Incorporates logic to avoid watermarking tokens in repeated contexts,
+    preserving naturalness.
+
+    * **EOS Token Masking:** Supports masking end-of-sentence tokens to prevent their inclusion in watermarking
+    calculations.
+
+    * **Utility Functions:** Provides functions to compute g-values directly, check for context repetition, create
+    EOS token masks, and estimate expected mean g-values.
+
+    Refer to paper url: https://www.nature.com/articles/s41586-024-08025-4 for more details around this.
+
+    Args:
+        ngram_len (`int`):
+            Ngram length.
+        keys (`List[int]`):
+            A sequence of watermarking keys, one for each depth.
+        sampling_table_size (`int`):
+            Size of the sampling table.
+        sampling_table_seed (`int`):
+            Random seed to generate the sampling table.
+        context_history_size (`int`):
+            Size of the tensor to keep track of seen contexts.
+        device (`torch.device`):
+            Device to use.
+        skip_first_ngram_calls (`bool`, *optional*, defaults to `False`):
+            Whether to skip first ngram calls.
+        debug_mode (`bool`, optional, *optional*, defaults to `False`):
+            Logits are modified to uniform one got before watermarking modification is applied. This is to test the
+            implementation.
+
+    Examples:
+    ```python
+    >>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig
+
+    >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
+    >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b-it')
+
+    >>> # SynthID Text configuration
+    >>> watermarking_config = SynthIDTextWatermarkingConfig(
+    ...     keys=[654, 400, 836, 123, 340, 443, 597, 160, 57],
+    ...     ngram_len=5,
+    ... )
+
+    >>> # Generation with watermarking
+    >>> tokenized_prompts = tokenizer(["your prompts here"])
+    >>> output_sequences = model.generate(
+    ...     **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True,
+    ... )
+    >>> watermarked_text = tokenizer.batch_decode(output_sequences)
+    ```
+    """
+
+    def __init__(
+        self,
+        ngram_len: int,
+        keys: List[int],
+        sampling_table_size: int,
+        sampling_table_seed: int,
+        context_history_size: int,
+        device: torch.device,
+        skip_first_ngram_calls: bool = False,
+        debug_mode: bool = False,
+    ):
+        self.ngram_len = ngram_len
+        self.keys = torch.tensor(keys, device=device)
+
+        generator = torch.Generator(device=device).manual_seed(sampling_table_seed)
+        # A random sampling table is pre-computed and modulo table size is applied to map from a hash of ngram keys to
+        # g values, this is similar to the hashtable implementation used in
+        # https://github.com/facebookresearch/three_bricks. We note that the hashing employed in this repository is
+        # different from that used to watermark the Gemini App, and hence the detectors trained based on the
+        # hashing in this repository will not transfer to text generated by the Gemini App.
+        self.sampling_table = torch.randint(
+            low=0,
+            high=2,
+            size=(sampling_table_size,),
+            generator=generator,
+            device=device,
+        )
+        self.context_history_size = context_history_size
+        self.device = device
+        self.state = None
+        self.skip_first_ngram_calls = skip_first_ngram_calls
+        self.debug_mode = debug_mode
+
+    def _init_state(self, batch_size: int):
+        """Initializes the state."""
+        self.state = SynthIDTextWatermarkState(
+            batch_size=batch_size,
+            ngram_len=self.ngram_len,
+            context_history_size=self.context_history_size,
+            device=self.device,
+        )
+
+    def update_scores(self, scores: torch.FloatTensor, g_values: torch.FloatTensor) -> torch.FloatTensor:
+        """Updates scores using the g values.
+
+        We assume that the scores are in the log space.
+        Args:
+            scores (`torch.FloatTensor`): Scores (batch_size, vocab_size).
+            g_values (`torch.FloatTensor`): G valus (batch_size, vocab_size, depth).
+
+        Returns:
+            Updated scores (batch_size, vocab_size).
+        """
+        _, _, depth = g_values.shape
+
+        probs = torch.softmax(scores, dim=1)
+
+        for i in range(depth):
+            g_values_at_depth = g_values[:, :, i]
+            g_mass_at_depth = (g_values_at_depth * probs).sum(axis=1, keepdims=True)
+            probs = probs * (1 + g_values_at_depth - g_mass_at_depth)
+
+        log_probs = torch.log(probs)
+        log_probs = torch.where(torch.isfinite(log_probs), log_probs, torch.finfo(log_probs.dtype).min)
+        return log_probs
+
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        self._check_input_ids_shape(input_ids)
+        batch_size, vocab_size = scores.shape
+
+        if self.debug_mode:
+            scores = torch.ones_like(scores)
+
+        # Currently indices is just a arange to compute watermarking on the desnse logits.
+        all_indices = torch.stack([torch.arange(vocab_size, device=self.device) for _ in range(batch_size)])
+
+        if self.state is None:
+            # Initialize watermarking state if it does not exist.
+            self._init_state(batch_size)
+        else:
+            # Append last input id (which is the input id added in last call) to the
+            # previous context so we have the context to be used for current
+            # watermarking.
+            self.state.context = torch.concat(
+                (self.state.context, input_ids[:, -1:]),
+                dim=1,
+            )
+            self.state.context = self.state.context[:, 1:]
+
+        if self.state is None:
+            raise ValueError("self.state can't be None! Call `self._init_state` to initialize the state.")
+
+        self.state.num_calls += 1
+
+        # Don't watermark the first ngram_len - 1 tokens if set.
+        if self.skip_first_ngram_calls and self.state.num_calls < self.ngram_len:
+            return scores
+
+        # 2. Generate random keys for each ngram key combination.
+        ngram_keys, hash_result_with_just_context = self._compute_keys(self.state.context, all_indices)
+        # ngram_keys shape [batch_size, top_k, depth]
+
+        # 3. Sample g values.
+        g_values = self.sample_g_values(ngram_keys)
+        # g_values shape [batch_size, top_k, depth]
+
+        # 4. Modify scores.
+        updated_scores = self.update_scores(scores, g_values)
+        # updated scores shape [batch_size, top_k]
+
+        # 5. Check if the current watermarking context was previously used, if yes skip watermarking.
+        hash_result_with_just_context = hash_result_with_just_context[:, None]
+        is_repeated_context = (self.state.context_history == hash_result_with_just_context).any(
+            dim=1,
+            keepdim=True,
+        )
+        self.state.context_history = torch.concat(
+            (hash_result_with_just_context, self.state.context_history),
+            dim=1,
+        )[:, :-1]
+
+        updated_watermarked_scores = torch.where(
+            is_repeated_context,
+            input=scores,
+            other=updated_scores,
+        )
+        return updated_watermarked_scores
+
+    def accumulate_hash(
+        self,
+        current_hash: torch.LongTensor,
+        data: torch.LongTensor,
+        multiplier: int = 6364136223846793005,
+        increment: int = 1,
+    ) -> torch.LongTensor:
+        """
+        Accumulate hash of data on current hash.
+
+        Method uses adapted linear congruential generator with newlib/musl parameters.
+
+        This function has following property -
+        f(x, data[T]) = f(f(x, data[:T - 1]), data[T])
+
+        This function expects current_hash.shape and data.shape[:-1] to
+        match/broadcastable.
+
+        Args:
+            current_hash (`torch.LongTensor`):
+                (shape,)
+            data (`torch.LongTensor`):
+                (shape, tensor_len)
+            multiplier (`int`, optional, *optional*, defaults to 6364136223846793005):
+                multiplier of linear congruential generator
+            increment (`int`, optional, *optional*, defaults to 1):
+                increment of linear congruential generator
+
+        Returns:
+            updated hash (shape,)
+        """
+        for i in range(data.shape[-1]):
+            current_hash = torch.add(current_hash, data[..., i])
+            current_hash = torch.mul(current_hash, multiplier)
+            current_hash = torch.add(current_hash, increment)
+        return current_hash
+
+    def compute_ngram_keys(self, ngrams: torch.LongTensor) -> torch.LongTensor:
+        """Computes random keys for each ngram and depth.
+
+        Args:
+            ngrams (`torch.LongTensor`):
+                Ngrams (batch_size, num_ngrams, ngram_len).
+
+        Returns:
+            ngram keys (batch_size, num_ngrams, depth).
+        """
+        if len(ngrams.shape) != 3:
+            raise ValueError(
+                "Ngrams should be of shape (batch_size, num_ngrams, ngram_len), but" f" is {ngrams.shape}"
+            )
+        if ngrams.shape[2] != self.ngram_len:
+            raise ValueError(
+                "Ngrams should be of shape (batch_size, num_ngrams, ngram_len),"
+                f" where ngram_len is {self.ngram_len}, but is {ngrams.shape}"
+            )
+        batch_size, _, _ = ngrams.shape
+
+        hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
+        # hash_result shape [batch_size,]
+        # ngrams shape [batch_size, num_ngrams, ngram_len]
+        hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 1), out_dims=1)(hash_result, ngrams)
+        # hash_result shape [batch_size, num_ngrams]
+
+        keys = self.keys[None, None, :, None]
+        # hash_result shape [batch_size, num_ngrams]
+        # keys shape [1, 1, depth, 1]
+        hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 2), out_dims=2)(hash_result, keys)
+        # hash_result shape [batch_size, num_ngrams, depth]
+
+        return hash_result
+
+    def _compute_keys(
+        self, n_minus_1_grams: torch.LongTensor, indices: torch.LongTensor
+    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
+        """Computes random keys for each ngram and depth.
+
+        Args:
+            n_minus_1_grams (`torch.LongTensor`):
+                Ngrams (batch_size, ngram_len - 1).
+            indices (`torch.LongTensor`):
+                indices of the continuations (batch_size, num_indices)
+
+        Returns:
+            Ngram keys (batch_size, num_indices, depth).
+        """
+        batch_size, _ = n_minus_1_grams.shape
+
+        hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
+        # First hash n_minus_1 gram, for each batch entry we have a single
+        # n_minus_1 gram context.
+        # hash_result shape [batch_size]
+        # n_minus_1_gram shape [batch_size, ngram_len - 1]
+        hash_result_with_just_context = self.accumulate_hash(hash_result, n_minus_1_grams)
+        # hash_result shape [batch_size,]
+        # Indices is of shape [batch_size, num_indices], so we make it
+        # [batch_size, num_indices, 1] so we can vmap over num_indices dim.
+        hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 1), out_dims=1)(
+            hash_result_with_just_context, indices[:, :, None]
+        )
+        # hash_result shape [batch_size, num_indices]
+        # Basically we have a hash for each batch entry and each indices
+        # Now we add watermarking keys to this hash.
+        # keys are of shape [depth,]
+        # We add batch, num_indices and data dimension to this making it
+        # [1, 1, depth, 1].
+        # So we can vmap over the depth dimension for compute_hash
+        keys = self.keys[None, None, :, None]
+        hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 2), out_dims=2)(hash_result, keys)
+        # hash_result shape should be [batch_size, num_indices, depth]
+        return hash_result, hash_result_with_just_context
+
+    def sample_g_values(self, ngram_keys: torch.LongTensor) -> torch.LongTensor:
+        """
+        Samples g values from Bernoulli distribution.
+
+        It is not possible to pass random keys in a vectorized way in torch. Instead
+        we pre-compute a random sampling table, and use apply modulo table size to
+        map from ngram keys (int64) to g values.
+
+        Args:
+            ngram_keys (`torch.LongTensor`):
+                Random keys (batch_size, num_ngrams, depth).
+
+        Returns:
+            G values (batch_size, num_ngrams, depth).
+        """
+        (sampling_table_size,) = self.sampling_table.shape
+        sampling_table = self.sampling_table.reshape((1, 1, sampling_table_size))
+        ngram_keys = ngram_keys % sampling_table_size
+        return torch.take_along_dim(sampling_table, indices=ngram_keys, dim=2)
+
+    def _check_input_ids_shape(self, input_ids: torch.LongTensor):
+        """Checks the shape of input ids."""
+        if len(input_ids.shape) != 2:
+            raise ValueError("Input ids should be of shape (batch_size, input_len), but is" f" {input_ids.shape}")
+
+    def compute_g_values(self, input_ids: torch.LongTensor) -> torch.LongTensor:
+        """
+        Computes g values for each ngram from the given sequence of tokens.
+
+        Args:
+            input_ids (`torch.LongTensor`):
+                Input token ids (batch_size, input_len).
+
+        Returns:
+            G values (batch_size, input_len - (ngram_len - 1), depth).
+        """
+        self._check_input_ids_shape(input_ids)
+        ngrams = input_ids.unfold(dimension=1, size=self.ngram_len, step=1)
+        ngram_keys = self.compute_ngram_keys(ngrams)
+        return self.sample_g_values(ngram_keys)
+
+    def compute_context_repetition_mask(self, input_ids: torch.LongTensor) -> torch.LongTensor:
+        """
+        Computes repetition mask.
+
+        0 and 1 stand for repeated and not repeated context n-1 grams respectively.
+
+        Args:
+            input_ids (`torch.LongTensor`):
+                Input token ids (batch_size, input_len).
+
+        Returns:
+            Repetitions mask (batch_size, input_len - (ngram_len - 1)).
+        """
+        self._check_input_ids_shape(input_ids)
+        batch_size, _ = input_ids.shape
+        state = SynthIDTextWatermarkState(
+            batch_size=batch_size,
+            ngram_len=self.ngram_len,
+            context_history_size=self.context_history_size,
+            device=self.device,
+        )
+        contexts = input_ids[:, :-1].unfold(
+            dimension=1,
+            size=self.ngram_len - 1,
+            step=1,
+        )
+        _, num_contexts, _ = contexts.shape
+
+        are_repeated_contexts = []
+        for i in range(num_contexts):
+            context = contexts[:, i, :]
+            hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
+            context_hash = self.accumulate_hash(hash_result, context)[:, None]
+            is_repeated_context = (state.context_history == context_hash).any(
+                dim=1,
+                keepdim=True,
+            )
+            are_repeated_contexts.append(is_repeated_context)
+            state.context_history = torch.concat(
+                (context_hash, state.context_history),
+                dim=1,
+            )[:, :-1]
+        are_repeated_contexts = torch.concat(are_repeated_contexts, dim=1)
+
+        return torch.logical_not(are_repeated_contexts)
+
+    def compute_eos_token_mask(self, input_ids: torch.LongTensor, eos_token_id: int) -> torch.LongTensor:
+        """
+        Computes repetitions mask.
+
+        1 stands for ngrams that don't contain EOS tokens and vice versa.
+
+        Args:
+            input_ids (`torch.LongTensor`):
+                Input token ids (batch_size, input_len).
+            eos_token_id (`int`):
+                EOS token ID.
+
+        Returns:
+            EOS token mask (batch_size, input_len).
+        """
+        self._check_input_ids_shape(input_ids)
+        noneos_masks = []
+        all_eos_equated = input_ids == eos_token_id
+        for eos_equated in all_eos_equated:
+            nonzero_idx = torch.nonzero(eos_equated)
+            noneos_mask = torch.ones_like(eos_equated)
+            if nonzero_idx.shape[0] != 0:
+                noneos_mask[nonzero_idx[0][0] :] = 0
+            noneos_masks.append(noneos_mask)
+        return torch.stack(noneos_masks, dim=0)
+
+    def expected_mean_g_value(self, vocab_size: int, coinflip_prob: float = 0.5) -> float:
+        """
+        Compute expected mean g-value after watermarking, assuming uniform LM dist.
+
+        This is the theoretical expected value for single-layer watermarking.
+
+        Args:
+            vocab_size (`int`):
+                The size of the vocabulary.
+            coinflip_prob arg_name (`float`, *optional*, defaults to 0.5):
+                Probability of 1 in boolean prf.
+
+        Returns:
+            The expected mean g-value for watermarked text.
+        """
+        return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size))
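For reference, the `expected_mean_g_value` helper that closes this hunk encodes the expected mean g-value of watermarked text under a uniform LM distribution. With coin-flip probability p (the probability of a 1 in the Bernoulli PRF, 0.5 by default) and vocabulary size V, the returned value is

```latex
\mathbb{E}[\bar{g} \mid \text{watermarked}] = p + p(1 - p)\left(1 - \frac{1}{V}\right)
```

whereas the mean g-value of unwatermarked token sequences stays around p = 0.5, which is what the new uniformity tests further down check.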
@@ -92,7 +92,6 @@ from .logits_process import (
     TopPLogitsWarper,
     TypicalLogitsWarper,
     UnbatchedClassifierFreeGuidanceLogitsProcessor,
-    WatermarkLogitsProcessor,
 )
 from .stopping_criteria import (
     ConfidenceCriteria,

@@ -1011,15 +1010,7 @@ class GenerationMixin:
             )
         if generation_config.watermarking_config is not None:
             processors.append(
-                WatermarkLogitsProcessor(
-                    vocab_size=self.config.vocab_size,
-                    device=device,
-                    greenlist_ratio=generation_config.watermarking_config.greenlist_ratio,
-                    bias=generation_config.watermarking_config.bias,
-                    hashing_key=generation_config.watermarking_config.hashing_key,
-                    seeding_scheme=generation_config.watermarking_config.seeding_scheme,
-                    context_width=generation_config.watermarking_config.context_width,
-                )
+                generation_config.watermarking_config.construct_processor(self.config.vocab_size, device)
             )

         # TODO (joao): find a strategy to specify the order of the processors
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team
+# Copyright 2024 The HuggingFace Inc. team and Google DeepMind.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -16,19 +16,22 @@
 import collections
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import numpy as np
+import torch
+from torch import nn
+from torch.nn import BCELoss

-from ..configuration_utils import PretrainedConfig
-from ..utils import is_torch_available, logging
-from .configuration_utils import WatermarkingConfig
+from ..modeling_utils import PreTrainedModel
+from ..utils import ModelOutput, is_torch_available, logging
+from .configuration_utils import PretrainedConfig, WatermarkingConfig


 if is_torch_available():
     import torch

-    from .logits_process import WatermarkLogitsProcessor
+    from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor


 logger = logging.get_logger(__name__)
@@ -237,3 +240,310 @@ class WatermarkDetector:
             confidence=confidence,
         )
         return prediction
+
+
+class BayesianDetectorConfig(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`BayesianDetectorModel`]. It is used to
+    instantiate a Bayesian Detector model according to the specified arguments.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        watermarking_depth (`int`, *optional*):
+            The number of tournament layers.
+        base_rate (`float1`, *optional*, defaults to 0.5):
+            Prior probability P(w) that a text is watermarked.
+    """
+
+    def __init__(self, watermarking_depth: int = None, base_rate: float = 0.5, **kwargs):
+        self.watermarking_depth = watermarking_depth
+        self.base_rate = base_rate
+        # These can be set later to store information about this detector.
+        self.model_name = None
+        self.watermarking_config = None
+
+        super().__init__(**kwargs)
+
+    def set_detector_information(self, model_name, watermarking_config):
+        self.model_name = model_name
+        self.watermarking_config = watermarking_config
+
+
+@dataclass
+class BayesianWatermarkDetectorModelOutput(ModelOutput):
+    """
+    Base class for outputs of models predicting if the text is watermarked.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss.
+        posterior_probabilities (`torch.FloatTensor` of shape `(1,)`):
+            Multiple choice classification loss.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    posterior_probabilities: Optional[torch.FloatTensor] = None
+
+
+class BayesianDetectorWatermarkedLikelihood(nn.Module):
+    """Watermarked likelihood model for binary-valued g-values.
+
+    This takes in g-values and returns p(g_values|watermarked).
+    """
+
+    def __init__(self, watermarking_depth: int):
+        """Initializes the model parameters."""
+        super().__init__()
+        self.watermarking_depth = watermarking_depth
+        self.beta = torch.nn.Parameter(-2.5 + 0.001 * torch.randn(1, 1, watermarking_depth))
+        self.delta = torch.nn.Parameter(0.001 * torch.randn(1, 1, self.watermarking_depth, watermarking_depth))
+
+    def _compute_latents(self, g_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Computes the unique token probability distribution given g-values.
+
+        Args:
+            g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
+                PRF values.
+
+        Returns:
+            p_one_unique_token and p_two_unique_tokens, both of shape
+            [batch_size, seq_len, watermarking_depth]. p_one_unique_token[i,t,l]
+            gives the probability of there being one unique token in a tournament
+            match on layer l, on timestep t, for batch item i.
+            p_one_unique_token[i,t,l] + p_two_unique_token[i,t,l] = 1.
+        """
+        # Tile g-values to produce feature vectors for predicting the latents
+        # for each layer in the tournament; our model for the latents psi is a
+        # logistic regression model psi = sigmoid(delta * x + beta).
+
+        # [batch_size, seq_len, watermarking_depth, watermarking_depth]
+        x = torch.repeat_interleave(torch.unsqueeze(g_values, dim=-2), self.watermarking_depth, axis=-2)
+
+        # mask all elements above -1 diagonal for autoregressive factorization
+        x = torch.tril(x, diagonal=-1)
+
+        # [batch_size, seq_len, watermarking_depth]
+        # (i, j, k, l) x (i, j, k, l) -> (i, j, k) einsum equivalent
+        logits = (self.delta[..., None, :] @ x.type(self.delta.dtype)[..., None]).squeeze() + self.beta
+
+        p_two_unique_tokens = torch.sigmoid(logits)
+        p_one_unique_token = 1 - p_two_unique_tokens
+        return p_one_unique_token, p_two_unique_tokens
+
+    def forward(self, g_values: torch.Tensor) -> torch.Tensor:
+        """Computes the likelihoods P(g_values|watermarked).
+
+        Args:
+            g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
+                g-values (values 0 or 1)
+
+        Returns:
+            p(g_values|watermarked) of shape [batch_size, seq_len, watermarking_depth].
+        """
+        p_one_unique_token, p_two_unique_tokens = self._compute_latents(g_values)
+
+        # P(g_tl | watermarked) is equal to
+        # 0.5 * [ (g_tl+0.5) * p_two_unique_tokens + p_one_unique_token].
+        return 0.5 * ((g_values + 0.5) * p_two_unique_tokens + p_one_unique_token)
+
+
+class BayesianDetectorModel(PreTrainedModel):
+    r"""
+    Bayesian classifier for watermark detection.
+
+    This detector uses Bayes' rule to compute a watermarking score, which is the sigmoid of the log of ratio of the
+    posterior probabilities P(watermarked|g_values) and P(unwatermarked|g_values). Please see the section on
+    BayesianScore in the paper for further details.
+    Paper URL: https://www.nature.com/articles/s41586-024-08025-4
+
+    Note that this detector only works with non-distortionary Tournament-based watermarking using the Bernoulli(0.5)
+    g-value distribution.
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`BayesianDetectorConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+    """
+
+    config_class = BayesianDetectorConfig
+    base_model_prefix = "model"
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.watermarking_depth = config.watermarking_depth
+        self.base_rate = config.base_rate
+        self.likelihood_model_watermarked = BayesianDetectorWatermarkedLikelihood(
+            watermarking_depth=self.watermarking_depth
+        )
+        self.prior = torch.nn.Parameter(torch.tensor([self.base_rate]))
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, nn.Parameter):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+
+    def _compute_posterior(
+        self,
+        likelihoods_watermarked: torch.Tensor,
+        likelihoods_unwatermarked: torch.Tensor,
+        mask: torch.Tensor,
+        prior: float,
+    ) -> torch.Tensor:
+        """
+        Compute posterior P(w|g) given likelihoods, mask and prior.
+
+        Args:
+            likelihoods_watermarked (`torch.Tensor` of shape `(batch, length, depth)`):
+                Likelihoods P(g_values|watermarked) of g-values under watermarked model.
+            likelihoods_unwatermarked (`torch.Tensor` of shape `(batch, length, depth)`):
+                Likelihoods P(g_values|unwatermarked) of g-values under unwatermarked model.
+            mask (`torch.Tensor` of shape `(batch, length)`):
+                A binary array indicating which g-values should be used. g-values with mask value 0 are discarded.
+            prior (`float`):
+                the prior probability P(w) that the text is watermarked.
+
+        Returns:
+            Posterior probability P(watermarked|g_values), shape [batch].
+        """
+        mask = torch.unsqueeze(mask, dim=-1)
+        prior = torch.clamp(prior, min=1e-5, max=1 - 1e-5)
+        log_likelihoods_watermarked = torch.log(torch.clamp(likelihoods_watermarked, min=1e-30, max=float("inf")))
+        log_likelihoods_unwatermarked = torch.log(torch.clamp(likelihoods_unwatermarked, min=1e-30, max=float("inf")))
+        log_odds = log_likelihoods_watermarked - log_likelihoods_unwatermarked
+
+        # Sum relative surprisals (log odds) across all token positions and layers.
+        relative_surprisal_likelihood = torch.einsum("i...->i", log_odds * mask)
+
+        # Compute the relative surprisal prior
+        relative_surprisal_prior = torch.log(prior) - torch.log(1 - prior)
+
+        # Combine prior and likelihood.
+        # [batch_size]
+        relative_surprisal = relative_surprisal_prior + relative_surprisal_likelihood
+
+        # Compute the posterior probability P(w|g) = sigmoid(relative_surprisal).
+        return torch.sigmoid(relative_surprisal)
+
+    def forward(
+        self,
+        g_values: torch.Tensor,
+        mask: torch.Tensor,
+        labels: Optional[torch.Tensor] = None,
+        loss_batch_weight=1,
+        return_dict=False,
+    ) -> BayesianWatermarkDetectorModelOutput:
+        """
+        Computes the watermarked posterior P(watermarked|g_values).
+
+        Args:
+            g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth, ...)`):
+                g-values (with values 0 or 1)
+            mask:
+                A binary array shape [batch_size, seq_len] indicating which g-values should be used. g-values with mask
+                value 0 are discarded.
+
+        Returns:
+            p(watermarked | g_values), of shape [batch_size].
+        """
+
+        likelihoods_watermarked = self.likelihood_model_watermarked(g_values)
+        likelihoods_unwatermarked = 0.5 * torch.ones_like(g_values)
+        out = self._compute_posterior(
+            likelihoods_watermarked=likelihoods_watermarked,
+            likelihoods_unwatermarked=likelihoods_unwatermarked,
+            mask=mask,
+            prior=self.prior,
+        )
+
+        loss = None
+        if labels is not None:
+            loss_fct = BCELoss()
+            loss_unwweight = torch.sum(self.likelihood_model_watermarked.delta**2)
+            loss_weight = loss_unwweight * loss_batch_weight
+            loss = loss_fct(torch.clamp(out, 1e-5, 1 - 1e-5), labels) + loss_weight
+
+        if not return_dict:
+            return (out,) if loss is None else (out, loss)
+
+        return BayesianWatermarkDetectorModelOutput(loss=loss, posterior_probabilities=out)
+
+
+class SynthIDTextWatermarkDetector:
+    r"""
+    SynthID text watermark detector class.
+
+    This class has to be initialized with the trained bayesian detector module check script
+    in examples/synthid_text/detector_training.py for example in training/saving/loading this
+    detector module. The folder also showcases example use case of this detector.
+
+    Parameters:
+        detector_module ([`BayesianDetectorModel`]):
+            Bayesian detector module object initialized with parameters.
+            Check examples/research_projects/synthid_text/detector_training.py for usage.
+        logits_processor (`SynthIDTextWatermarkLogitsProcessor`):
+            The logits processor used for watermarking.
+        tokenizer (`Any`):
+            The tokenizer used for the model.
+
+    Examples:
+    ```python
+    >>> from transformers import (
+    ...     AutoTokenizer, BayesianDetectorModel, SynthIDTextWatermarkLogitsProcessor, SynthIDTextWatermarkDetector
+    ... )
+
+    >>> # Load the detector. See examples/research_projects/synthid_text for training a detector.
+    >>> detector_model = BayesianDetectorModel.from_pretrained("joaogante/dummy_synthid_detector")
+    >>> logits_processor = SynthIDTextWatermarkLogitsProcessor(
+    ...     **detector_model.config.watermarking_config, device="cpu"
+    ... )
+    >>> tokenizer = AutoTokenizer.from_pretrained(detector_model.config.model_name)
+    >>> detector = SynthIDTextWatermarkDetector(detector_model, logits_processor, tokenizer)
+
+    >>> # Test whether a certain string is watermarked
+    >>> test_input = tokenizer(["This is a test input"], return_tensors="pt")
+    >>> is_watermarked = detector(test_input.input_ids)
+    ```
+    """
+
+    def __init__(
+        self,
+        detector_module: BayesianDetectorModel,
+        logits_processor: SynthIDTextWatermarkLogitsProcessor,
+        tokenizer: Any,
+    ):
+        self.detector_module = detector_module
+        self.logits_processor = logits_processor
+        self.tokenizer = tokenizer
+
+    def __call__(self, tokenized_outputs: torch.Tensor):
+        # eos mask is computed, skip first ngram_len - 1 tokens
+        # eos_mask will be of shape [batch_size, output_len]
+        eos_token_mask = self.logits_processor.compute_eos_token_mask(
+            input_ids=tokenized_outputs,
+            eos_token_id=self.tokenizer.eos_token_id,
+        )[:, self.logits_processor.ngram_len - 1 :]
+
+        # context repetition mask is computed
+        context_repetition_mask = self.logits_processor.compute_context_repetition_mask(
+            input_ids=tokenized_outputs,
+        )
+        # context repitition mask shape [batch_size, output_len - (ngram_len - 1)]
+
+        combined_mask = context_repetition_mask * eos_token_mask
+
+        g_values = self.logits_processor.compute_g_values(
+            input_ids=tokenized_outputs,
+        )
+        # g values shape [batch_size, output_len - (ngram_len - 1), depth]
+        return self.detector_module(g_values, combined_mask)
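Restated in the paper's notation (a readability note, not part of the patch), `BayesianDetectorModel._compute_posterior` above combines the prior log-odds with a masked sum of per-position, per-layer log-likelihood ratios, where the unwatermarked likelihood is fixed at 0.5:

```latex
P(\text{watermarked} \mid g) = \sigma\!\left( \log\frac{P(w)}{1 - P(w)}
  + \sum_{t,\ell} m_t \left[ \log P(g_{t\ell} \mid w) - \log P(g_{t\ell} \mid \bar{w}) \right] \right),
\qquad P(g_{t\ell} \mid \bar{w}) = 0.5
```

Here the mask values m_t are the combined EOS/context-repetition mask that `SynthIDTextWatermarkDetector.__call__` passes in.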
@@ -191,6 +191,20 @@ class AlternatingCodebooksLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class BayesianDetectorConfig(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class BayesianDetectorModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class BeamScorer(metaclass=DummyObject):
     _backends = ["torch"]


@@ -457,6 +471,27 @@ class SuppressTokensLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class SynthIDTextWatermarkDetector(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class SynthIDTextWatermarkingConfig(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class SynthIDTextWatermarkLogitsProcessor(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class TemperatureLogitsWarper(metaclass=DummyObject):
     _backends = ["torch"]

@@ -16,6 +16,7 @@
 import unittest
 from typing import List, Union

+import numpy as np
 from parameterized import parameterized

 from transformers import is_torch_available

@@ -48,6 +49,7 @@ if is_torch_available():
         PrefixConstrainedLogitsProcessor,
         RepetitionPenaltyLogitsProcessor,
         SequenceBiasLogitsProcessor,
+        SynthIDTextWatermarkLogitsProcessor,
         TemperatureLogitsWarper,
         TopKLogitsWarper,
         TopPLogitsWarper,

@@ -975,3 +977,187 @@ class LogitsProcessorTest(unittest.TestCase):
         scores_wo_bias = scores[:, -1].clone()
         out = watermark(input_ids=input_ids, scores=scores)
         self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())
+
+    @parameterized.expand([(5, 3, 10000), (10, 5, 1000)])
+    def test_synthidtext_watermarking_processor_bias_uniformity(self, ngram_len, num_layers, vocab_size):
+        """Test SynthID watermarked distribution bias uniformity over iterations."""
+        torch.manual_seed(0)
+        np.random.seed(0)
+        watermarking_config = {
+            "ngram_len": ngram_len,
+            "keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
+            "sampling_table_size": 2**16,
+            "sampling_table_seed": 0,
+            "context_history_size": 512,
+            "device": torch_device,
+        }
+        batch_size = 100000
+        ngrams = torch.randint(
+            low=0,
+            high=vocab_size,
+            size=(batch_size, ngram_len),
+            device=torch_device,
+        )
+
+        logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
+        g_values = logits_processor.compute_g_values(ngrams)
+        g_values_mean = torch.mean(torch.mean(g_values.float(), dim=0))
+        self.assertAlmostEqual(g_values_mean, 0.5, delta=0.01)
+
+    @parameterized.expand([(10000, 3), (1000, 20)])
+    def test_synthidtext_watermark_processor_bias_uniformity_across_vocab(self, vocab_size, num_layers):
+        """Test SynthID watermarked distribution bias uniformity over vocabs of the model."""
+        batch_size = 1000
+        ngram_len = 5
+        torch.manual_seed(0)
+        np.random.seed(0)
+        watermarking_config = {
+            "ngram_len": ngram_len,
+            "keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
+            "sampling_table_size": 2**16,
+            "sampling_table_seed": 0,
+            "context_history_size": 512,
+            "device": torch_device,
+        }
+        n_minus_1_grams = torch.randint(
+            low=0,
+            high=vocab_size,
+            size=(batch_size, watermarking_config["ngram_len"] - 1),
+            device=torch_device,
+        )
+
+        logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
+        ngram_keys, _ = logits_processor._compute_keys(
+            n_minus_1_grams,
+            torch.stack([torch.arange(vocab_size, device=torch_device) for _ in range(batch_size)]),
+        )
+
+        g_values = logits_processor.sample_g_values(ngram_keys)
+        # g_values shape should be [batch_size, vocab_size, num_layers]
+        g_values_mean = torch.mean(torch.mean(g_values.float(), dim=1))
+        self.assertAlmostEqual(g_values_mean, 0.5, delta=0.001)
+
+    @parameterized.expand([(2, "uniform"), (10, "uniform"), (2, "random"), (10, "random")])
+    def test_synthidtext_watermark_processor_distributional_convergence(self, vocab_size, logits_type):
+        """Check if watermarked distribution converges to unwatermarked logits distribution."""
+        batch_size = 1500
+        num_keys = 1000
+
+        updated_softmaxes = 0
+        np.random.seed(0)
+        torch.manual_seed(0)
+        if logits_type == "uniform":
+            fixed_logits = torch.ones((batch_size, vocab_size), device=torch_device)
+        elif logits_type == "random":
+            fixed_logits = torch.rand(
+                (
+                    1,
+                    vocab_size,
+                ),
+                device=torch_device,
+            )
+            fixed_logits = fixed_logits.repeat(batch_size, 1)
+        else:
+            raise ValueError(f"Unrecognized logits_type {logits_type}")
+        for _ in range(num_keys):
+            watermarking_config = {
+                "ngram_len": 5,
+                "keys": np.random.randint(0, 10**9, size=(1,), dtype=np.int64),
+                "sampling_table_size": 2**16,
+                "sampling_table_seed": 0,
+                "context_history_size": 1024,
+                "device": torch_device,
+            }
+
+            logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
+
+            ngrams = torch.randint(
+                low=0,
+                high=vocab_size,
+                size=(batch_size, watermarking_config["ngram_len"]),
+                device=torch_device,
+            )
+
+            # Insert ngram-1 into logit_processor state.
+            for idx in range(watermarking_config["ngram_len"] - 1):
+                _ = logits_processor(ngrams[:, :idx], fixed_logits)
+
+            updated_scores = logits_processor(ngrams, fixed_logits)
+            updated_softmaxes += torch.nn.functional.softmax(updated_scores, dim=1).cpu().numpy()
+
+        updated_softmaxes = np.mean(updated_softmaxes, axis=0) / num_keys
+        is_close = torch.all(
+            torch.isclose(
+                torch.tensor(updated_softmaxes, device=torch_device),
+                torch.nn.Softmax()(fixed_logits[0]),  # Take any batch entry, all are same.
+                atol=1e-3,
+                rtol=0,
+            )
+        )
+        self.assertTrue(is_close)
+
+    @parameterized.expand([(2, 10, 1, 0.01), (100, 5, 1, 0.01), (100, 10, 2, 0.02)])
+    def test_synthidtext_watermark_processor_bias_test(self, vocab_size, ngram_len, num_layers, atol):
+        """Test SynthID watermarking bias matches theoretical value."""
+        batch_size = 20000
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        np.random.seed(0)
+
+        keys = [np.random.randint(0, 10**9) for _ in range(num_layers)]
+        # Use 10**9 rather than vocab_size to ensure variety in (n-1)-grams.
+        context = torch.randint(
+            low=0,
+            high=10**9,
+            size=(batch_size, ngram_len - 1),
+            dtype=torch.int64,
+            generator=generator,
+            device=torch_device,
+        )
+
+        context_history_size = 1024
+        logits_processor = SynthIDTextWatermarkLogitsProcessor(
+            ngram_len=ngram_len,
+            keys=keys,
+            sampling_table_size=2**16,
+            sampling_table_seed=0,
+            context_history_size=context_history_size,
+            device=torch_device,
+        )
+
+        scores = torch.ones(
+            (batch_size, vocab_size),
+            dtype=torch.float64,
+            device=torch_device,
+        )
+        # Init state of the logits processor.
+        logits_processor(context, scores)
+        # insert context into the state.
+        for idx in range(1, ngram_len - 1):
+            _ = logits_processor(context[:, :idx], scores)
+
+        updated_scores = logits_processor(context, scores)
+
+        probs = torch.nn.functional.softmax(updated_scores, dim=1)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        next_tokens = torch.multinomial(
+            probs,
+            num_samples=1,
+            generator=generator,
+        )
+
+        ngrams = torch.concat((context, next_tokens), dim=1)
+        g_values = logits_processor.compute_g_values(ngrams)
+        mean_g_values = g_values.mean(dtype=torch.float64, dim=(0, 1))
+
+        expected_mean_g_value = logits_processor.expected_mean_g_value(
+            vocab_size=vocab_size,
+        )
+        is_close = torch.all(
+            torch.isclose(
+                mean_g_values,
+                torch.tensor(expected_mean_g_value, dtype=torch.float64, device=torch_device),
+                atol=atol,
+                rtol=0,
+            )
+        )
+        self.assertTrue(is_close)
@ -84,6 +84,7 @@ if is_torch_available():
|
|||||||
SampleEncoderDecoderOutput,
|
SampleEncoderDecoderOutput,
|
||||||
StoppingCriteria,
|
StoppingCriteria,
|
||||||
StoppingCriteriaList,
|
StoppingCriteriaList,
|
||||||
|
SynthIDTextWatermarkingConfig,
|
||||||
WatermarkDetector,
|
WatermarkDetector,
|
||||||
WatermarkingConfig,
|
WatermarkingConfig,
|
||||||
)
|
)
|
||||||
@@ -2517,9 +2518,9 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
         self.assertListEqual(low_output.tolist(), high_output.tolist())

     @slow
-    def test_watermark_generation(self):
-        tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
-        model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
+    def test_green_red_watermark_generation(self):
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
         tokenizer.pad_token_id = tokenizer.eos_token_id
         model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
         input_len = model_inputs["input_ids"].shape[-1]
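For contrast with SynthID, the renamed green/red-list test follows the WatermarkingConfig plus WatermarkDetector flow whose assertions appear in the next hunk. A rough sketch under assumed kwargs (bias value, return_dict flag, and generation settings are illustrative, not the test's exact values):

from transformers import AutoModelForCausalLM, AutoTokenizer, WatermarkDetector, WatermarkingConfig

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id

inputs = tokenizer("I will be", return_tensors="pt")
watermark_config = WatermarkingConfig(bias=2.5)  # assumed bias value
out = model.generate(**inputs, watermarking_config=watermark_config, do_sample=False, max_new_tokens=20)

detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermark_config)
detection = detector(out, return_dict=True)
print(detection.prediction)  # expected True for watermarked continuations, up to detector error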
@@ -2548,6 +2549,61 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
         self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
         self.assertListEqual(detection_out.prediction.tolist(), [False])

+    @slow
+    def test_synthid_text_watermark_generation_mean_expected_bias(self):
+        """Check the mean bias inserted by the watermarking algorithm."""
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
+        input_len = 5
+        batch_size = 200
+
+        # generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :)
+        watermark_config = SynthIDTextWatermarkingConfig(keys=[10, 20], ngram_len=5, debug_mode=True)
+        logits_processor = watermark_config.construct_processor(model.config.vocab_size, torch_device)
+        mean_g_values_repeats = []
+        for _ in range(40):
+            input_ids = torch.zeros(
+                (batch_size, input_len),
+                dtype=torch.int64,
+                device=torch_device,
+            )
+            model_inputs = {
+                "input_ids": input_ids,
+                "attention_mask": torch.ones_like(input_ids, device=torch_device),
+            }
+            output = model.generate(
+                **model_inputs, watermarking_config=watermark_config, do_sample=True, max_length=500, top_k=1000
+            )
+            g_values = logits_processor.compute_g_values(input_ids=output[:, input_len:])
+            context_repetition_mask = logits_processor.compute_context_repetition_mask(
+                input_ids=output[:, input_len:],
+            ).unsqueeze(dim=2)
+
+            mean_g_values = torch.masked.mean(
+                g_values,
+                mask=context_repetition_mask,
+                dim=0,
+                keepdim=True,
+                dtype=torch.float64,
+            )
+            mean_g_values_repeats.append(mean_g_values)
+
+        mean_g_values = torch.concat(mean_g_values_repeats, dim=0).mean(dim=0)
+        expected_mean_g_value = logits_processor.expected_mean_g_value(
+            vocab_size=model.config.vocab_size,
+        )
+        atol = 0.03
+        is_close = torch.isclose(
+            mean_g_values,
+            torch.tensor(expected_mean_g_value, dtype=torch.float64),
+            atol=atol,
+            rtol=0,
+        )
+        self.assertTrue(torch.all(is_close))
+
     @slow
     def test_beam_search_example_integration(self):
         # PT-only test: TF doesn't have a BeamSearchScorer