20
20
import numpy as np
21
21
import tensorflow as tf
22
22
23
+ # TODO: Wrap functions in @tf.function once
24
+ # https://github.com/tensorflow/tensorflow/issues/29075 is resolved
23
25
24
- @tf .function
25
26
def crf_sequence_score (inputs , tag_indices , sequence_lengths ,
26
27
transition_params ):
27
28
"""Computes the unnormalized score for a tag sequence.
@@ -66,7 +67,6 @@ def _multi_seq_fn():
66
67
return _multi_seq_fn ()
67
68
68
69
69
- @tf .function
70
70
def crf_multitag_sequence_score (inputs , tag_bitmap , sequence_lengths ,
71
71
transition_params ):
72
72
"""Computes the unnormalized score of all tag sequences matching
@@ -115,7 +115,6 @@ def _multi_seq_fn():
115
115
return _multi_seq_fn ()
116
116
117
117
118
- @tf .function
119
118
def crf_log_norm (inputs , sequence_lengths , transition_params ):
120
119
"""Computes the normalization for a CRF.
121
120
@@ -163,7 +162,6 @@ def _multi_seq_fn():
163
162
return _multi_seq_fn ()
164
163
165
164
166
- @tf .function
167
165
def crf_log_likelihood (inputs ,
168
166
tag_indices ,
169
167
sequence_lengths ,
@@ -201,7 +199,6 @@ def crf_log_likelihood(inputs,
201
199
return log_likelihood , transition_params
202
200
203
201
204
- @tf .function
205
202
def crf_unary_score (tag_indices , sequence_lengths , inputs ):
206
203
"""Computes the unary scores of tag sequences.
207
204
@@ -236,7 +233,6 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs):
236
233
return unary_scores
237
234
238
235
239
- @tf .function
240
236
def crf_binary_score (tag_indices , sequence_lengths , transition_params ):
241
237
"""Computes the binary scores of tag sequences.
242
238
@@ -272,7 +268,6 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params):
272
268
return binary_scores
273
269
274
270
275
- @tf .function
276
271
def crf_forward (inputs , state , transition_params , sequence_lengths ):
277
272
"""Computes the alpha values in a linear-chain CRF.
278
273
@@ -364,7 +359,6 @@ def call(self, inputs, state):
364
359
return backpointers , new_state
365
360
366
361
367
- @tf .function
368
362
def crf_decode_forward (inputs , state , transition_params , sequence_lengths ):
369
363
mask = tf .sequence_mask (sequence_lengths , tf .shape (inputs )[1 ])
370
364
crf_fwd_cell = CrfDecodeForwardRnnCell (transition_params )
@@ -373,7 +367,6 @@ def crf_decode_forward(inputs, state, transition_params, sequence_lengths):
373
367
return crf_fwd_layer (inputs , state , mask = mask )
374
368
375
369
376
- @tf .function
377
370
def crf_decode_backward (inputs , state ):
378
371
"""Computes backward decoding in a linear-chain CRF."""
379
372
inputs = tf .transpose (inputs , [1 , 0 , 2 ])
@@ -387,7 +380,6 @@ def _scan_fn(state, inputs):
387
380
return tf .transpose (tf .scan (_scan_fn , inputs , state ), [1 , 0 , 2 ])
388
381
389
382
390
- @tf .function
391
383
def crf_decode (potentials , transition_params , sequence_length ):
392
384
"""Decode the highest scoring sequence of tags in TensorFlow.
393
385
0 commit comments