Script for distilling zero-shot classifier to more efficient student (#10244)
* add zero-shot distillation script
* readme wordsmithing
* clean up code
* add multi-gpu teacher inference plus tidying up more code
* add use_fast_tokenizer arg
* update results in readme
* more readme wordsmithing
* style
* Add handle to readme

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* fix code block
* add error+docs about distributed & tpu
* add @sgugger format requests
* xla -> tpu
* support fp16 for teacher preds
* no checkpoint by default
* add demo colab link
* add model sharing prompt + model link
* correct resulting acc of example

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in: parent 97e688bc22, commit c6fe17557e

examples/research_projects/zero-shot-distillation/README.md (Normal file, 155 lines)

@@ -0,0 +1,155 @@
# Zero-shot classifier distillation

Author: @joeddav

This script provides a way to improve the speed and memory performance of a zero-shot classifier by training a more
efficient student model from the zero-shot teacher's predictions over an unlabeled dataset.

The zero-shot classification pipeline uses a model pre-trained on natural language inference (NLI) to determine the
compatibility of a set of candidate class names with a given sequence. This serves as a convenient out-of-the-box
classifier without the need for labeled training data. However, for a given sequence, the method requires each
possible label to be fed through the large NLI model separately. Thus for `N` sequences and `K` classes, a total of
`N*K` forward passes through the model are required. This requirement slows inference considerably, particularly as
`K` grows.

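To make the `N*K` cost concrete, here is a minimal sketch (not the pipeline's actual internals) of how each sequence is paired with one NLI hypothesis per candidate label, each pair requiring its own forward pass through the teacher:

```python
sequences = ["A new moon has been discovered in Jupiter's orbit"]    # N = 1
class_names = ["the world", "sports", "business", "science/tech"]    # K = 4
hypothesis_template = "This example is {}."

# One (premise, hypothesis) pair per (sequence, class) combination.
nli_inputs = [
    (sequence, hypothesis_template.format(name))
    for sequence in sequences
    for name in class_names
]
print(len(nli_inputs))  # N * K = 4 forward passes for a single sequence
```
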
Given (1) an unlabeled corpus and (2) a set of candidate class names, the provided script trains a student model
with a standard classification head with `K` output dimensions. The resulting student model can then be used for
classifying novel text instances with a significant boost in speed and memory performance while retaining similar
classification performance to the original zero-shot model.

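For intuition, the student is trained with a soft-target cross-entropy against the teacher's predicted label distribution; this is the same loss computed by `DistillationTrainer.compute_loss` in `distill_classifier.py`. A minimal sketch with made-up tensors:

```python
import torch

# Made-up shapes for illustration: a batch of 8 examples and K = 4 classes.
teacher_probs = torch.softmax(torch.randn(8, 4), dim=-1)  # soft targets from the teacher
student_logits = torch.randn(8, 4)                        # raw outputs of the K-way student head

# Soft cross-entropy: -sum_k p_teacher(k) * log p_student(k), averaged over the batch.
loss = -torch.sum(teacher_probs * student_logits.log_softmax(dim=-1), dim=-1).mean()
```
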
### Usage

A teacher NLI model can be distilled to a more efficient student model by running `distill_classifier.py`:

```
python distill_classifier.py \
--data_file <unlabeled_data.txt> \
--class_names_file <class_names.txt> \
--output_dir <output_dir>
```

`<unlabeled_data.txt>` should be a text file with a single unlabeled example per line. `<class_names.txt>` is a text file with one class name per line.

Other optional arguments include:

- `--teacher_name_or_path` (default: `roberta-large-mnli`): The name or path of the NLI teacher model.
- `--student_name_or_path` (default: `distilbert-base-uncased`): The name or path of the student model which will
be fine-tuned to copy the teacher predictions.
- `--hypothesis_template` (default `"This example is {}."`): The template used to turn each label into an NLI-style
hypothesis when generating teacher predictions. This template must include a `{}` or similar syntax for the
candidate label to be inserted into the template. For example, the default template is `"This example is {}."` With
the candidate label `sports`, this would be fed into the model like `[CLS] sequence to classify [SEP] This example
is sports . [SEP]`.
- `--multi_class`: Whether or not multiple candidate labels can be true. By default, the scores are normalized such
that the sum of the label likelihoods for each sequence is 1. If `--multi_class` is passed, the labels are
considered independent and probabilities are normalized for each candidate by doing a softmax of the entailment
score vs. the contradiction score (see the sketch after this list). This is sometimes called "multi-class
multi-label" classification.
- `--temperature` (default: `1.0`): The temperature applied to the softmax of the teacher model predictions. A
higher temperature results in a student with smoother (lower-confidence) predictions than the teacher, while a value
`<1` results in a higher-confidence, peaked distribution. The default `1.0` is equivalent to no smoothing.
- `--teacher_batch_size` (default: `32`): The batch size used for generating a single set of teacher predictions.
Does not affect training. Use `--per_device_train_batch_size` to change the training batch size.

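The `--multi_class` and `--temperature` behavior described above corresponds to the following toy computation, which mirrors `get_teacher_predictions` in `distill_classifier.py` (the logit values here are made up):

```python
import torch

# Toy teacher NLI logits for one sequence and K = 3 candidate labels:
# shape (K, 2), columns are [contradiction, entailment].
nli_logits = torch.tensor([[1.2, -0.3], [-0.8, 2.1], [0.4, 0.9]])
temperature = 2.0   # >1 smooths the teacher distribution, <1 sharpens it
multi_class = False

if multi_class:
    # each label gets an independent entailment-vs-contradiction softmax
    probs = (nli_logits / temperature).softmax(dim=-1)[:, 1]
else:
    # default: entailment logits compete across labels, so the K probabilities sum to 1
    probs = (nli_logits / temperature).softmax(dim=0)[:, 1]
print(probs)
```
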
Any of the arguments in the 🤗 Trainer's
[`TrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html?#trainingarguments) can also be
modified, such as `--learning_rate`, `--fp16`, `--no_cuda`, `--warmup_steps`, etc. Run `python distill_classifier.py
-h` for a full list of available arguments or consult the [Trainer
documentation](https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments).

> **Note**: Distributed and TPU training are not currently supported. Single-node multi-GPU is supported, however,
and will run automatically if multiple GPUs are available.

### Example: Topic classification

> A full colab demo notebook of this example can be found [here](https://colab.research.google.com/drive/1mjBjd0cR8G57ZpsnFCS3ngGyo5nCa9ya?usp=sharing).

Let's say we're interested in classifying news articles into one of four topic categories: "the world", "sports",
"business", or "science/tech". We have an unlabeled dataset, [AG's News](https://huggingface.co/datasets/ag_news),
which corresponds to this problem (in reality AG's News is annotated, but we will pretend it is not for the sake of
example).

We can use an NLI model like `roberta-large-mnli` for zero-shot classification like so:

```python
>>> from transformers import pipeline

>>> class_names = ["the world", "sports", "business", "science/tech"]
>>> hypothesis_template = "This text is about {}."
>>> sequence = "A new moon has been discovered in Jupiter's orbit"

>>> zero_shot_classifier = pipeline("zero-shot-classification", model="roberta-large-mnli")
>>> zero_shot_classifier(sequence, class_names, hypothesis_template=hypothesis_template)
{'sequence': "A new moon has been discovered in Jupiter's orbit",
 'labels': ['science/tech', 'the world', 'business', 'sports'],
 'scores': [0.7035840153694153, 0.18744826316833496, 0.06027870625257492, 0.04868902638554573]}
```

Unfortunately, inference is slow since each of our 4 class names must be fed through the large model for every
sequence to be classified. But with our unlabeled data we can distill the model to a small distilbert classifier to
make future inference much faster.

To run the script, we will need to put each training example (text only) from AG's News on its own line in
`agnews/train_unlabeled.txt`, and each of the four class names in the newline-separated `agnews/class_names.txt`.
Then we can run distillation with the following command:

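One possible way to create those two files (this preparation step is not part of the script, and it assumes the Hub dataset `ag_news` exposes the article text under a `text` column):

```python
import os

from datasets import load_dataset

os.makedirs("agnews", exist_ok=True)
train = load_dataset("ag_news", split="train")

# One unlabeled example per line (replace newlines so an article doesn't span several lines).
with open("agnews/train_unlabeled.txt", "w") as f:
    for text in train["text"]:
        f.write(text.replace("\n", " ") + "\n")

# One class name per line, in the order we want the student's output dimensions.
with open("agnews/class_names.txt", "w") as f:
    f.write("\n".join(["the world", "sports", "business", "science/tech"]) + "\n")
```
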
```bash
python distill_classifier.py \
--data_file ./agnews/train_unlabeled.txt \
--class_names_file ./agnews/class_names.txt \
--teacher_name_or_path roberta-large-mnli \
--hypothesis_template "This text is about {}." \
--output_dir ./agnews/distilled
```

The script will generate a set of soft zero-shot predictions from `roberta-large-mnli` for each example in
`agnews/train_unlabeled.txt`. It will then train a student distilbert classifier on the teacher predictions and
save the resulting model in `./agnews/distilled`.

The resulting model can then be loaded and used like any other pre-trained classifier:

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("./agnews/distilled")
tokenizer = AutoTokenizer.from_pretrained("./agnews/distilled")
```

and even used trivially with a `TextClassificationPipeline`:

```python
>>> from transformers import TextClassificationPipeline

>>> distilled_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)
>>> distilled_classifier(sequence)
[[{'label': 'the world', 'score': 0.14899294078350067},
  {'label': 'sports', 'score': 0.03205857425928116},
  {'label': 'business', 'score': 0.05943061783909798},
  {'label': 'science/tech', 'score': 0.7595179080963135}]]
```

> Tip: pass `device=0` when constructing a pipeline to run on a GPU.

As we can see, the results of the student closely resemble those of the teacher despite never having seen this
example during training. Now let's do a quick & dirty speed comparison simulating 16K examples with a batch size of
16:

```python
for _ in range(1000):
    zero_shot_classifier([sequence] * 16, class_names)
# runs in 1m 23s on a single V100 GPU
```

```python
%%time
for _ in range(1000):
    distilled_classifier([sequence] * 16)
# runs in 10.3s on a single V100 GPU
```

As we can see, the distilled student model runs an order of magnitude faster than its teacher NLI model. This is
also a setting where we only have `K=4` possible labels. The higher the number of classes for a given task, the more
drastic the speedup will be, since the zero-shot teacher's complexity scales linearly with the number of classes.

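A quick back-of-the-envelope count of forward passes makes the scaling explicit (the sequence count here is illustrative):

```python
# Teacher cost grows linearly with the number of classes K; the student's does not.
N = 16_000  # number of sequences to classify
for K in (4, 10, 100):
    print(f"K={K:>3}: teacher needs {N * K:>9,} forward passes, student needs {N:,}")
```
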
Since we secretly have access to ground truth labels for AG's News, we can evaluate the accuracy of each model. The
original zero-shot model `roberta-large-mnli` gets an accuracy of 69.3% on the held-out test set. After training a
student on the unlabeled training set, the distilled model gets a similar score of 70.4%.

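As a rough sketch, the distilled pipeline's accuracy can be checked against the AG's News test split along these lines (this assumes the dataset's integer labels 0-3 follow the same order as our class names file, i.e. World, Sports, Business, Sci/Tech):

```python
from datasets import load_dataset

test = load_dataset("ag_news", split="test").select(range(1000))  # subset to keep it quick

correct = 0
for example in test:
    scores = distilled_classifier(example["text"])[0]  # list of {label, score} dicts, in label-id order
    pred = max(range(len(scores)), key=lambda i: scores[i]["score"])
    correct += int(pred == example["label"])
print(f"accuracy on {len(test)} test examples: {correct / len(test):.1%}")
```
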
Lastly, you can share the distilled model with the community and/or use it with our inference API by [uploading it
to the 🤗 Hub](https://huggingface.co/transformers/model_sharing.html). We've uploaded the distilled model from this
example at
[joeddav/distilbert-base-uncased-agnews-student](https://huggingface.co/joeddav/distilbert-base-uncased-agnews-student).

examples/research_projects/zero-shot-distillation/distill_classifier.py (Normal file, 338 lines)

@@ -0,0 +1,338 @@
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import List, Optional

import torch
from datasets import Dataset
from torch import nn
from tqdm.auto import tqdm

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
    utils,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process


DESCRIPTION = """
Distills an NLI-based zero-shot classifier to a smaller, more efficient model with a fixed set of candidate class
names. Useful for speeding up zero-shot classification in cases where labeled training data is not available, but
when only a single fixed set of classes is needed. Takes a teacher NLI model, student classifier model, unlabeled
dataset, and set of K possible class names. Yields a single classifier with K outputs corresponding to the provided
class names.
"""

logger = logging.getLogger(__name__)

@dataclass
class TeacherModelArguments:
    teacher_name_or_path: Optional[str] = field(
        default="roberta-large-mnli", metadata={"help": "The NLI/zero-shot teacher model to be distilled."}
    )
    hypothesis_template: Optional[str] = field(
        default="This example is {}.",
        metadata={
            "help": (
                "Template used to turn class names into mock hypotheses for teacher NLI model. Must include {{}}"
                " where class name is inserted."
            )
        },
    )
    teacher_batch_size: Optional[int] = field(
        default=32, metadata={"help": "Batch size for generating teacher predictions."}
    )
    multi_class: Optional[bool] = field(
        default=False,
        metadata={
            "help": (
                "Allow multiple classes to be true rather than forcing them to sum to 1 (sometimes called"
                " multi-class multi-label classification)."
            )
        },
    )
    temperature: Optional[float] = field(
        default=1.0, metadata={"help": "Temperature applied to teacher softmax for distillation."}
    )


@dataclass
class StudentModelArguments:
    student_name_or_path: Optional[str] = field(
        default="distilbert-base-uncased",
        metadata={"help": "The student model to be fine-tuned to copy the teacher's predictions."},
    )


@dataclass
class DataTrainingArguments:
    data_file: str = field(metadata={"help": "Text file with one unlabeled instance per line."})
    class_names_file: str = field(metadata={"help": "Text file with one class name per line."})
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the Rust tokenizers library) or not."},
    )


@dataclass
class DistillTrainingArguments(TrainingArguments):
    output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    per_device_train_batch_size: int = field(
        default=32, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=128, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    num_train_epochs: float = field(default=1.0, metadata={"help": "Total number of training epochs to perform."})
    do_train: bool = field(default=True, metadata={"help": "Whether to run training of student model."})
    do_eval: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to evaluate the agreement of the final student predictions and the teacher predictions"
                " after training."
            )
        },
    )
    save_total_limit: Optional[int] = field(
        default=0,
        metadata={
            "help": (
                "Limit the total amount of checkpoints."
                " Deletes the older checkpoints in the output_dir. Default is 0 (no checkpoints)."
            )
        },
    )

class DistillationTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        target_p = inputs["labels"]
        outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
        logits = outputs[0]

        # soft cross-entropy between the teacher's label distribution and the student's log-probabilities
        loss = -torch.sum(target_p * logits.log_softmax(dim=-1), axis=-1).mean()

        if return_outputs:
            return loss, outputs

        return loss


def read_lines(path):
    lines = []
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if len(line) > 0:
                lines.append(line)
    return lines


def get_premise_hypothesis_pairs(examples, class_names, hypothesis_template):
    premises = []
    hypotheses = []
    for example in examples:
        for name in class_names:
            premises.append(example)
            hypotheses.append(hypothesis_template.format(name))
    return premises, hypotheses


def get_entailment_id(config):
    for label, ind in config.label2id.items():
        if label.lower().startswith("entail"):
            return ind
    logging.warning("Could not identify entailment dimension from teacher config label2id. Setting to -1.")
    return -1

def get_teacher_predictions(
    model_path: str,
    examples: List[str],
    class_names: List[str],
    hypothesis_template: str,
    batch_size: int,
    temperature: float,
    multi_class: bool,
    use_fast_tokenizer: bool,
    no_cuda: bool,
    fp16: bool,
):
    """
    Gets predictions by the same method as the zero-shot pipeline but with DataParallel & more efficient batching
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model_config = model.config
    if not no_cuda and torch.cuda.is_available():
        model = nn.DataParallel(model.cuda())
        batch_size *= len(model.device_ids)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast_tokenizer)

    premises, hypotheses = get_premise_hypothesis_pairs(examples, class_names, hypothesis_template)
    logits = []

    for i in tqdm(range(0, len(premises), batch_size)):
        batch_premises = premises[i : i + batch_size]
        batch_hypotheses = hypotheses[i : i + batch_size]

        encodings = tokenizer(
            batch_premises,
            batch_hypotheses,
            padding=True,
            truncation="only_first",
            return_tensors="pt",
        )

        with torch.cuda.amp.autocast(enabled=fp16):
            with torch.no_grad():
                outputs = model(**encodings)
        logits.append(outputs.logits.detach().cpu().float())

    entail_id = get_entailment_id(model_config)
    contr_id = -1 if entail_id == 0 else 0
    logits = torch.cat(logits, dim=0)  # N*K x 3
    nli_logits = logits.reshape(len(examples), len(class_names), -1)[..., [contr_id, entail_id]]  # N x K x 2

    if multi_class:
        # softmax over (contr, entail) logits for each class independently
        nli_prob = (nli_logits / temperature).softmax(-1)
    else:
        # softmax over entail logits across classes s.t. class probabilities sum to 1.
        nli_prob = (nli_logits / temperature).softmax(1)

    return nli_prob[..., 1]  # N x K

def main():
    parser = HfArgumentParser(
        (DataTrainingArguments, TeacherModelArguments, StudentModelArguments, DistillTrainingArguments),
        description=DESCRIPTION,
    )

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        data_args, teacher_args, student_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        data_args, teacher_args, student_args, training_args = parser.parse_args_into_dataclasses()

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        utils.logging.set_verbosity_info()
        utils.logging.enable_default_handler()
        utils.logging.enable_explicit_format()

    if training_args.local_rank != -1:
        raise ValueError("Distributed training is not currently supported.")
    if training_args.tpu_num_cores is not None:
        raise ValueError("TPU acceleration is not currently supported.")

    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # 1. read in data
    examples = read_lines(data_args.data_file)
    class_names = read_lines(data_args.class_names_file)

    # 2. get teacher predictions and load into dataset
    logger.info("Generating predictions from zero-shot teacher model")
    teacher_soft_preds = get_teacher_predictions(
        teacher_args.teacher_name_or_path,
        examples,
        class_names,
        teacher_args.hypothesis_template,
        teacher_args.teacher_batch_size,
        teacher_args.temperature,
        teacher_args.multi_class,
        data_args.use_fast_tokenizer,
        training_args.no_cuda,
        training_args.fp16,
    )
    dataset = Dataset.from_dict(
        {
            "text": examples,
            "labels": teacher_soft_preds,
        }
    )

    # 3. create student
    logger.info("Initializing student model")
    model = AutoModelForSequenceClassification.from_pretrained(
        student_args.student_name_or_path, num_labels=len(class_names)
    )
    tokenizer = AutoTokenizer.from_pretrained(student_args.student_name_or_path, use_fast=data_args.use_fast_tokenizer)
    model.config.id2label = {i: label for i, label in enumerate(class_names)}
    model.config.label2id = {label: i for i, label in enumerate(class_names)}

    # 4. train student on teacher predictions
    dataset = dataset.map(tokenizer, input_columns="text")
    dataset.set_format("torch")

    def compute_metrics(p, return_outputs=False):
        preds = p.predictions.argmax(-1)
        proxy_labels = p.label_ids.argmax(-1)  # "label_ids" are actually distributions
        return {"agreement": (preds == proxy_labels).mean().item()}

    trainer = DistillationTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset,
        compute_metrics=compute_metrics,
    )

    if training_args.do_train:
        logger.info("Training student model on teacher predictions")
        trainer.train()

    if training_args.do_eval:
        agreement = trainer.evaluate(eval_dataset=dataset)["eval_agreement"]
        logger.info(f"Agreement of student and teacher predictions: {agreement * 100:0.2f}%")

    trainer.save_model()


if __name__ == "__main__":
    main()