add class wireframes for Bert decoder

This commit is contained in:
Rémi Louf 2019-10-07 16:43:21 +02:00
parent dda1adad6d
commit 31adbb247c

View File

@ -331,6 +331,14 @@ class BertEncoderLayer(nn.Module):
return outputs
class BertDecoderLayer(nn.Module):
def __init__(self, config):
raise NotImplementedError
def forward(self, hidden_state, encoder_output):
raise NotImplementedError
class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
@ -363,6 +371,14 @@ class BertEncoder(nn.Module):
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class BertDecoder(nn.Module):
def __init__(self, config):
raise NotImplementedError
def forward(self, encoder_output):
raise NotImplementedError
class BertPooler(nn.Module):
def __init__(self, config):
super(BertPooler, self).__init__()