mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
add class wireframes for Bert decoder
This commit is contained in:
parent
dda1adad6d
commit
31adbb247c
@ -331,6 +331,14 @@ class BertEncoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class BertDecoderLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, hidden_state, encoder_output):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertEncoder, self).__init__()
|
||||
@ -363,6 +371,14 @@ class BertEncoder(nn.Module):
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
class BertDecoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, encoder_output):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertPooler, self).__init__()
|
||||
|
Loading…
Reference in New Issue
Block a user