update docstrings

This commit is contained in:
thomwolf 2019-02-07 23:15:20 +01:00
parent e77721e4fe
commit eb8fda51f4

View File

@ -984,7 +984,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
Inputs:
`input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
with the token indices selected in the range [0, self.config.n_token[
`mems`: optional memory of hidden states from previous forward passes
as a list (num layers) of hidden states at the entry of each layer
each hidden state has shape [self.config.mem_len, bsz, self.config.d_model]
Outputs:
A tuple of (last_hidden_state, new_mems)
`last_hidden_state`: the encoded-hidden-states at the top of the model
@ -1220,6 +1222,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
def forward(self, input_ids, mems=None):
""" Params:
input_ids :: [len, bsz]
mems :: optional mems from previous forward passes (or init_mems)
list (num layers) of mem states at the entry of each layer
shape :: [self.config.mem_len, bsz, self.config.d_model]
Returns:
tuple (last_hidden, new_mems) where:
new_mems: list (num layers) of mem states at the entry of each layer
@ -1250,8 +1255,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
Inputs:
`input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
with the token indices selected in the range [0, self.config.n_token[
`target`: a torch.LongTensor of shape [sequence_length, batch_size]
`target`: an optional torch.LongTensor of shape [sequence_length, batch_size]
with the target token indices selected in the range [0, self.config.n_token[
`mems`: an optional memory of hidden states from previous forward passes
as a list (num layers) of hidden states at the entry of each layer
each hidden state has shape [self.config.mem_len, bsz, self.config.d_model]
Outputs:
A tuple of (last_hidden_state, new_mems)