Skip to content

Commit d3f1986

Browse files
committed
Add missing docstrings
1 parent b884618 commit d3f1986

File tree

1 file changed

+48
-1
lines changed

1 file changed

+48
-1
lines changed

tensorflow_addons/text/crf.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,19 @@ def crf_forward(inputs, state, transition_params, sequence_lengths):
272272
"""Computes the alpha values in a linear-chain CRF.
273273
274274
See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
275+
276+
Args:
277+
inputs: A [batch_size, num_tags] matrix of unary potentials.
278+
state: A [batch_size, num_tags] matrix containing the previous alpha
279+
values.
280+
transition_params: A [num_tags, num_tags] matrix of binary potentials.
281+
This matrix is expanded into a [1, num_tags, num_tags] in preparation
282+
for the broadcast summation occurring within the cell.
283+
sequence_lengths: A [batch_size] vector of true sequence lengths.
284+
285+
Returns:
286+
new_alphas: A [batch_size, num_tags] matrix containing the
287+
new alpha values.
275288
"""
276289

277290
sequence_lengths = tf.maximum(
@@ -351,6 +364,17 @@ def build(self, input_shape):
351364
super(CrfDecodeForwardRnnCell, self).build(input_shape)
352365

353366
def call(self, inputs, state):
367+
"""Build the CrfDecodeForwardRnnCell.
368+
369+
Args:
370+
inputs: A [batch_size, num_tags] matrix of unary potentials.
371+
state: A [batch_size, num_tags] matrix containing the previous step's
372+
score values.
373+
374+
Returns:
375+
backpointers: A [batch_size, num_tags] matrix of backpointers.
376+
new_state: A [batch_size, num_tags] matrix of new score values.
377+
"""
354378
state = tf.expand_dims(state[0], 2)
355379
transition_scores = state + self._transition_params
356380
new_state = inputs + tf.reduce_max(transition_scores, [1])
@@ -360,6 +384,19 @@ def call(self, inputs, state):
360384

361385

362386
def crf_decode_forward(inputs, state, transition_params, sequence_lengths):
387+
"""Computes forward decoding in a linear-chain CRF.
388+
389+
Args:
390+
inputs: A [batch_size, num_tags] matrix of unary potentials.
391+
state: A [batch_size, num_tags] matrix containing the previous step's
392+
score values.
393+
transition_params: A [num_tags, num_tags] matrix of binary potentials.
394+
sequence_lengths: A [batch_size] vector of true sequence lengths.
395+
396+
Returns:
397+
backpointers: A [batch_size, num_tags] matrix of backpointers.
398+
new_state: A [batch_size, num_tags] matrix of new score values.
399+
"""
363400
mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1])
364401
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
365402
crf_fwd_layer = tf.keras.layers.RNN(
@@ -368,7 +405,17 @@ def crf_decode_forward(inputs, state, transition_params, sequence_lengths):
368405

369406

370407
def crf_decode_backward(inputs, state):
371-
"""Computes backward decoding in a linear-chain CRF."""
408+
"""Computes backward decoding in a linear-chain CRF.
409+
410+
Args:
411+
inputs: A [batch_size, num_tags] matrix of
412+
backpointer of next step (in time order).
413+
state: A [batch_size, 1] matrix of tag index of next step.
414+
415+
Returns:
416+
new_tags: A [batch_size, num_tags]
417+
tensor containing the new tag indices.
418+
"""
372419
inputs = tf.transpose(inputs, [1, 0, 2])
373420

374421
def _scan_fn(state, inputs):

0 commit comments

Comments
 (0)