transformers/examples/research_projects/synthid_text/detector_training.py

# coding=utf-8
# Copyright 2024 Google DeepMind.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import dataclasses
import enum
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BayesianDetectorConfig,
    BayesianDetectorModel,
    SynthIDTextWatermarkDetector,
    SynthIDTextWatermarkingConfig,
    SynthIDTextWatermarkLogitsProcessor,
)
from utils import (
    get_tokenized_uwm_outputs,
    get_tokenized_wm_outputs,
    process_raw_model_outputs,
    update_fn_if_fpr_tpr,
    upload_model_to_hf,
)
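
# This script trains a Bayesian detector for SynthID-Text watermarks: it samples
# watermarked outputs from the chosen model, pairs them with unwatermarked negatives,
# and fits a BayesianDetectorModel on the resulting g-values.
#
# Example invocation (illustrative; all flags are defined in the argument parser below,
# adjust the model name and batch sizes to your hardware):
#   python detector_training.py --model_name=google/gemma-2b-it \
#       --num_negatives=10000 --pos_batch_size=32 --num_pos_batch=313 \
#       --generation_length=512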


@enum.unique
class ValidationMetric(enum.Enum):
    """Validation metric used when training the detector."""

    TPR_AT_FPR = "tpr_at_fpr"
    CROSS_ENTROPY = "cross_entropy"


@dataclasses.dataclass
class TrainingArguments:
    """Training arguments pertaining to the training loop itself."""

    eval_metric: Optional[str] = dataclasses.field(
        default=ValidationMetric.TPR_AT_FPR, metadata={"help": "The evaluation metric used."}
    )


def train_detector(
    detector: torch.nn.Module,
    g_values: torch.Tensor,
    mask: torch.Tensor,
    watermarked: torch.Tensor,
    epochs: int = 250,
    learning_rate: float = 1e-3,
    minibatch_size: int = 64,
    seed: int = 0,
    l2_weight: float = 0.0,
    shuffle: bool = True,
    g_values_val: Optional[torch.Tensor] = None,
    mask_val: Optional[torch.Tensor] = None,
    watermarked_val: Optional[torch.Tensor] = None,
    verbose: bool = False,
    validation_metric: ValidationMetric = ValidationMetric.TPR_AT_FPR,
) -> Tuple[Dict[str, Any], float]:
"""Trains a Bayesian detector model.
Args:
g_values: g-values of shape [num_train, seq_len, watermarking_depth].
mask: A binary array shape [num_train, seq_len] indicating which g-values
should be used. g-values with mask value 0 are discarded.
watermarked: A binary array of shape [num_train] indicating whether the
example is watermarked (0: unwatermarked, 1: watermarked).
epochs: Number of epochs to train for.
learning_rate: Learning rate for optimizer.
minibatch_size: Minibatch size for training. Note that a minibatch
requires ~ 32 * minibatch_size * seq_len * watermarked_depth *
watermarked_depth bits of memory.
seed: Seed for parameter initialization.
l2_weight: Weight to apply to L2 regularization for delta parameters.
shuffle: Whether to shuffle before training.
g_values_val: Validation g-values of shape [num_val, seq_len,
watermarking_depth].
mask_val: Validation mask of shape [num_val, seq_len].
watermarked_val: Validation watermark labels of shape [num_val].
verbose: Boolean indicating verbosity of training. If true, the loss will
be printed. Defaulted to False.
use_tpr_fpr_for_val: Whether to use TPR@FPR=1% as metric for validation.
If false, use cross entropy loss.
Returns:
Tuple of
training_history: Training history keyed by epoch number where the
values are
dictionaries containing the loss, validation loss, and model
parameters,
keyed by
'loss', 'val_loss', and 'params', respectively.
min_val_loss: Minimum validation loss achieved during training.
"""
    # Set the random seed for reproducibility.
    torch.manual_seed(seed)

    # Shuffle the data if required.
    if shuffle:
        indices = torch.randperm(len(g_values))
        g_values = g_values[indices]
        mask = mask[indices]
        watermarked = watermarked[indices]

    # Initialize the optimizer.
    optimizer = torch.optim.Adam(detector.parameters(), lr=learning_rate)
    history = {}
    min_val_loss = float("inf")
    best_val_epoch = None

    for epoch in range(epochs):
        losses = []
        detector.train()
        num_batches = len(g_values) // minibatch_size
        for i in range(0, len(g_values), minibatch_size):
            end = i + minibatch_size
            if end > len(g_values):
                break
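            # Scale the L2 penalty by 1 / num_batches so that, summed over the
            # minibatches of an epoch, the total regularization weight is `l2_weight`.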
            loss_batch_weight = l2_weight / num_batches

            optimizer.zero_grad()
            loss = detector(
                g_values=g_values[i:end],
                mask=mask[i:end],
                labels=watermarked[i:end],
                loss_batch_weight=loss_batch_weight,
            )[1]
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        train_loss = sum(losses) / len(losses)

        # Compute the validation loss / metric if validation data was provided.
        val_loss = None
        val_losses = []
        if g_values_val is not None:
            detector.eval()
            if validation_metric == ValidationMetric.TPR_AT_FPR:
                val_loss = update_fn_if_fpr_tpr(
                    detector,
                    g_values_val,
                    mask_val,
                    watermarked_val,
                    minibatch_size=minibatch_size,
                )
            else:
                for i in range(0, len(g_values_val), minibatch_size):
                    end = i + minibatch_size
                    if end > len(g_values_val):
                        break
                    with torch.no_grad():
                        v_loss = detector(
                            g_values=g_values_val[i:end],
                            mask=mask_val[i:end],
                            labels=watermarked_val[i:end],
                            loss_batch_weight=0,
                        )[1]
                    val_losses.append(v_loss.item())
                val_loss = sum(val_losses) / len(val_losses)

        # Store the training history.
        history[epoch + 1] = {"loss": train_loss, "val_loss": val_loss}
        if verbose:
            if val_loss is not None:
                print(f"Epoch {epoch}: loss {train_loss} (train), {val_loss} (val)")
            else:
                print(f"Epoch {epoch}: loss {train_loss} (train)")

        if val_loss is not None and val_loss < min_val_loss:
            min_val_loss = val_loss
            best_val_epoch = epoch

    if verbose and best_val_epoch is not None:
        print(f"Best val Epoch: {best_val_epoch}, min_val_loss: {min_val_loss}")

    return history, min_val_loss
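

# A minimal usage sketch for `train_detector` (illustrative only; the tensor shapes and
# random inputs below are placeholders, whereas in this script the g-values, masks, and
# labels are produced by `process_raw_model_outputs` inside `train_best_detector`):
#
#   config = BayesianDetectorConfig(watermarking_depth=30)
#   detector = BayesianDetectorModel(config)
#   history, min_val_loss = train_detector(
#       detector=detector,
#       g_values=torch.randint(0, 2, (256, 64, 30)).float(),
#       mask=torch.ones(256, 64),
#       watermarked=torch.randint(0, 2, (256,)).float(),
#       epochs=10,
#       verbose=True,
#   )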


def train_best_detector(
    tokenized_wm_outputs: Union[List[np.ndarray], np.ndarray],
    tokenized_uwm_outputs: Union[List[np.ndarray], np.ndarray],
    logits_processor: SynthIDTextWatermarkLogitsProcessor,
    tokenizer: Any,
    torch_device: torch.device,
    test_size: float = 0.3,
    pos_truncation_length: Optional[int] = 200,
    neg_truncation_length: Optional[int] = 100,
    max_padded_length: int = 2300,
    n_epochs: int = 50,
    learning_rate: float = 2.1e-2,
    l2_weights: np.ndarray = np.logspace(-3, -2, num=4),
    verbose: bool = False,
    validation_metric: ValidationMetric = ValidationMetric.TPR_AT_FPR,
):
    """Trains and returns the best detector over the given range of L2 regularization weights.

    In practice, we have found that tuning pos_truncation_length,
    neg_truncation_length, n_epochs, learning_rate and l2_weights can help
    improve the performance of the detector. We recommend tuning these
    parameters for your data.
    """
    l2_weights = list(l2_weights)

    (
        train_g_values,
        train_masks,
        train_labels,
        cv_g_values,
        cv_masks,
        cv_labels,
    ) = process_raw_model_outputs(
        logits_processor,
        tokenizer,
        pos_truncation_length,
        neg_truncation_length,
        max_padded_length,
        tokenized_wm_outputs,
        test_size,
        tokenized_uwm_outputs,
        torch_device,
    )

    best_detector = None
    lowest_loss = float("inf")
    val_losses = []
    for l2_weight in l2_weights:
        config = BayesianDetectorConfig(watermarking_depth=len(logits_processor.keys))
        detector = BayesianDetectorModel(config).to(torch_device)
        _, min_val_loss = train_detector(
            detector=detector,
            g_values=train_g_values,
            mask=train_masks,
            watermarked=train_labels,
            g_values_val=cv_g_values,
            mask_val=cv_masks,
            watermarked_val=cv_labels,
            learning_rate=learning_rate,
            l2_weight=l2_weight,
            epochs=n_epochs,
            verbose=verbose,
            validation_metric=validation_metric,
        )
        val_losses.append(min_val_loss)
        if min_val_loss < lowest_loss:
            lowest_loss = min_val_loss
            best_detector = detector
    return best_detector, lowest_loss
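

# Note: the returned `best_detector` is a `BayesianDetectorModel`; to score text, it is
# wrapped together with the logits processor and tokenizer in a
# `SynthIDTextWatermarkDetector` (see the `--eval_detector_on_prompts` branch below).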


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="google/gemma-2b-it",
        help=("LM model to train the detector for."),
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help=("Temperature to sample from the model."),
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=40,
        help=("Top K for sampling."),
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=1.0,
        help=("Top P for sampling."),
    )
    parser.add_argument(
        "--num_negatives",
        type=int,
        default=10000,
        help=("Number of negatives for detector training."),
    )
    parser.add_argument(
        "--pos_batch_size",
        type=int,
        default=32,
        help=("Batch size of watermarked positives while sampling."),
    )
    parser.add_argument(
        "--num_pos_batch",
        type=int,
        default=313,
        help=("Number of positive batches for training."),
    )
    parser.add_argument(
        "--generation_length",
        type=int,
        default=512,
        help=("Generation length for sampling."),
    )
    parser.add_argument(
        "--save_model_to_hf_hub",
        action="store_true",
        help=("Whether to save the trained model to the HF Hub. By default it will be a private repo."),
    )
    parser.add_argument(
        "--load_from_hf_hub",
        action="store_true",
        help=(
            "Whether to load a trained detector model from the HF Hub. Make sure it was trained on the same "
            "underlying model we are loading in this script."
        ),
    )
    parser.add_argument(
        "--hf_hub_model_name",
        type=str,
        default=None,
        help=("HF Hub model name for loading or saving the model."),
    )
    parser.add_argument(
        "--eval_detector_on_prompts",
        action="store_true",
        help=("Evaluate the detector on a prompt and print the probability of a watermark."),
    )

    args = parser.parse_args()

    model_name = args.model_name
    temperature = args.temperature
    top_k = args.top_k
    top_p = args.top_p
    num_negatives = args.num_negatives
    pos_batch_size = args.pos_batch_size
    num_pos_batch = args.num_pos_batch
    if num_pos_batch < 10:
        raise ValueError("--num_pos_batch should be at least 10.")
    generation_length = args.generation_length
    save_model_to_hf_hub = args.save_model_to_hf_hub
    load_from_hf_hub = args.load_from_hf_hub
    repo_name = args.hf_hub_model_name
    eval_detector_on_prompts = args.eval_detector_on_prompts

    NEG_BATCH_SIZE = 32
    # Truncate outputs to this length for training.
    POS_TRUNCATION_LENGTH = 200
    NEG_TRUNCATION_LENGTH = 100
    # Pad truncated outputs to this length for equal shape across all batches.
    MAX_PADDED_LENGTH = 1000

    DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    if DEVICE.type not in ("cuda", "tpu"):
        raise ValueError("We have found the training stable on GPU and TPU; we are working on a fix for CPUs.")

    model = None
    if not load_from_hf_hub:
        # Change this to make your watermark unique. Check documentation in the paper to understand the
        # impact of these parameters.
        DEFAULT_WATERMARKING_CONFIG = {
            "ngram_len": 5,  # This corresponds to H=4 context window size in the paper.
            "keys": [
                654,
                400,
                836,
                123,
                340,
                443,
                597,
                160,
                57,
                29,
                590,
                639,
                13,
                715,
                468,
                990,
                966,
                226,
                324,
                585,
                118,
                504,
                421,
                521,
                129,
                669,
                732,
                225,
                90,
                960,
            ],
            "sampling_table_size": 2**16,
            "sampling_table_seed": 0,
            "context_history_size": 1024,
        }
        watermark_config = SynthIDTextWatermarkingConfig(**DEFAULT_WATERMARKING_CONFIG)

        model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token

        logits_processor = SynthIDTextWatermarkLogitsProcessor(**DEFAULT_WATERMARKING_CONFIG, device=DEVICE)
        tokenized_wm_outputs = get_tokenized_wm_outputs(
            model,
            tokenizer,
            watermark_config,
            num_pos_batch,
            pos_batch_size,
            temperature,
            generation_length,
            top_k,
            top_p,
            DEVICE,
        )
        tokenized_uwm_outputs = get_tokenized_uwm_outputs(num_negatives, NEG_BATCH_SIZE, tokenizer, DEVICE)

        best_detector, lowest_loss = train_best_detector(
            tokenized_wm_outputs=tokenized_wm_outputs,
            tokenized_uwm_outputs=tokenized_uwm_outputs,
            logits_processor=logits_processor,
            tokenizer=tokenizer,
            torch_device=DEVICE,
            test_size=0.3,
            pos_truncation_length=POS_TRUNCATION_LENGTH,
            neg_truncation_length=NEG_TRUNCATION_LENGTH,
            max_padded_length=MAX_PADDED_LENGTH,
            n_epochs=100,
            learning_rate=3e-3,
            l2_weights=[0],
            verbose=True,
            validation_metric=ValidationMetric.TPR_AT_FPR,
        )

        # Record the base model and watermarking config in the detector config so they
        # travel with the detector when it is saved or uploaded.
        best_detector.config.set_detector_information(
            model_name=model_name, watermarking_config=DEFAULT_WATERMARKING_CONFIG
        )
    else:
        if repo_name is None:
            raise ValueError("When loading a pretrained detector, --hf_hub_model_name cannot be None.")
        best_detector = BayesianDetectorModel.from_pretrained(repo_name).to(DEVICE)

    if save_model_to_hf_hub:
        upload_model_to_hf(best_detector, repo_name)

    # Evaluate the model response with the detector.
    if eval_detector_on_prompts:
        model_name = best_detector.config.model_name
        watermark_config_dict = best_detector.config.watermarking_config
        logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermark_config_dict, device=DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token
        synthid_text_detector = SynthIDTextWatermarkDetector(best_detector, logits_processor, tokenizer)

        if model is None:
            model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
        watermarking_config = SynthIDTextWatermarkingConfig(**watermark_config_dict)

        prompts = ["Write an essay on cats."]
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
        ).to(DEVICE)
        _, inputs_len = inputs["input_ids"].shape

        outputs = model.generate(
            **inputs,
            watermarking_config=watermarking_config,
            do_sample=True,
            max_length=inputs_len + generation_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        outputs = outputs[:, inputs_len:]
        result = synthid_text_detector(outputs)

        # Set these thresholds based on the expected FPR (false positive rate) and TPR (true positive rate).
        # Check our demo at HF Spaces for more info.
        upper_threshold = 0.95
        lower_threshold = 0.12
        if result[0][0] > upper_threshold:
            print("The text is watermarked.")
        elif lower_threshold < result[0][0] < upper_threshold:
            print("It is hard to determine if the text is watermarked or not.")
        else:
            print("The text is not watermarked.")