@@ -272,6 +272,19 @@ def crf_forward(inputs, state, transition_params, sequence_lengths):
272
272
"""Computes the alpha values in a linear-chain CRF.
273
273
274
274
See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
275
+
276
+ Args:
277
+ inputs: A [batch_size, num_tags] matrix of unary potentials.
278
+ state: A [batch_size, num_tags] matrix containing the previous alpha
279
+ values.
280
+ transition_params: A [num_tags, num_tags] matrix of binary potentials.
281
+ This matrix is expanded into a [1, num_tags, num_tags] in preparation
282
+ for the broadcast summation occurring within the cell.
283
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
284
+
285
+ Returns:
286
+ new_alphas: A [batch_size, num_tags] matrix containing the
287
+ new alpha values.
275
288
"""
276
289
277
290
sequence_lengths = tf .maximum (
@@ -351,6 +364,17 @@ def build(self, input_shape):
351
364
super (CrfDecodeForwardRnnCell , self ).build (input_shape )
352
365
353
366
def call (self , inputs , state ):
367
+ """Build the CrfDecodeForwardRnnCell.
368
+
369
+ Args:
370
+ inputs: A [batch_size, num_tags] matrix of unary potentials.
371
+ state: A [batch_size, num_tags] matrix containing the previous step's
372
+ score values.
373
+
374
+ Returns:
375
+ backpointers: A [batch_size, num_tags] matrix of backpointers.
376
+ new_state: A [batch_size, num_tags] matrix of new score values.
377
+ """
354
378
state = tf .expand_dims (state [0 ], 2 )
355
379
transition_scores = state + self ._transition_params
356
380
new_state = inputs + tf .reduce_max (transition_scores , [1 ])
@@ -360,6 +384,19 @@ def call(self, inputs, state):
360
384
361
385
362
386
def crf_decode_forward (inputs , state , transition_params , sequence_lengths ):
387
+ """Computes forward decoding in a linear-chain CRF.
388
+
389
+ Args:
390
+ inputs: A [batch_size, num_tags] matrix of unary potentials.
391
+ state: A [batch_size, num_tags] matrix containing the previous step's
392
+ score values.
393
+ transition_params: A [num_tags, num_tags] matrix of binary potentials.
394
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
395
+
396
+ Returns:
397
+ backpointers: A [batch_size, num_tags] matrix of backpointers.
398
+ new_state: A [batch_size, num_tags] matrix of new score values.
399
+ """
363
400
mask = tf .sequence_mask (sequence_lengths , tf .shape (inputs )[1 ])
364
401
crf_fwd_cell = CrfDecodeForwardRnnCell (transition_params )
365
402
crf_fwd_layer = tf .keras .layers .RNN (
@@ -368,7 +405,17 @@ def crf_decode_forward(inputs, state, transition_params, sequence_lengths):
368
405
369
406
370
407
def crf_decode_backward (inputs , state ):
371
- """Computes backward decoding in a linear-chain CRF."""
408
+ """Computes backward decoding in a linear-chain CRF.
409
+
410
+ Args:
411
+ inputs: A [batch_size, num_tags] matrix of
412
+ backpointer of next step (in time order).
413
+ state: A [batch_size, 1] matrix of tag index of next step.
414
+
415
+ Returns:
416
+ new_tags: A [batch_size, num_tags]
417
+ tensor containing the new tag indices.
418
+ """
372
419
inputs = tf .transpose (inputs , [1 , 0 , 2 ])
373
420
374
421
def _scan_fn (state , inputs ):
0 commit comments