Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 10:12:23 +06:00
Delete untested, broken Model2LSTM (#2968)
commit 129f0604ac (parent 0e84559d64)
@@ -18,7 +18,6 @@
 import logging
 import os
 
-import torch
 from torch import nn
 
 from .modeling_auto import AutoModel, AutoModelWithLMHead
@@ -294,21 +293,3 @@ class Model2Model(PreTrainedEncoderDecoder):
         )
 
         return model
-
-
-class Model2LSTM(PreTrainedEncoderDecoder):
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        if kwargs.get("decoder_model", None) is None:
-            # We will create a randomly initilized LSTM model as decoder
-            if "decoder_config" not in kwargs:
-                raise ValueError(
-                    "To load an LSTM in Encoder-Decoder model, please supply either: "
-                    " - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
-                    " - a dictionary of configuration parameters that will be used to initialize a"
-                    " torch.nn.LSTM model as `decoder_config` keyword argument. "
-                    " E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`"
-                )
-            kwargs["decoder_model"] = torch.nn.LSTM(kwargs.pop("decoder_config"))
-        model = super().from_pretrained(*args, **kwargs)
-        return model