Skip to content

Commit b884618

Browse files
committed
Remove @tf.function wrappers
1 parent a519d04 commit b884618

File tree

1 file changed

+2
-10
lines changed

1 file changed

+2
-10
lines changed

tensorflow_addons/text/crf.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
import numpy as np
2121
import tensorflow as tf
2222

23+
# TODO: Wrap functions in @tf.function once
24+
# https://github.com/tensorflow/tensorflow/issues/29075 is resolved
2325

24-
@tf.function
2526
def crf_sequence_score(inputs, tag_indices, sequence_lengths,
2627
transition_params):
2728
"""Computes the unnormalized score for a tag sequence.
@@ -66,7 +67,6 @@ def _multi_seq_fn():
6667
return _multi_seq_fn()
6768

6869

69-
@tf.function
7070
def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths,
7171
transition_params):
7272
"""Computes the unnormalized score of all tag sequences matching
@@ -115,7 +115,6 @@ def _multi_seq_fn():
115115
return _multi_seq_fn()
116116

117117

118-
@tf.function
119118
def crf_log_norm(inputs, sequence_lengths, transition_params):
120119
"""Computes the normalization for a CRF.
121120
@@ -163,7 +162,6 @@ def _multi_seq_fn():
163162
return _multi_seq_fn()
164163

165164

166-
@tf.function
167165
def crf_log_likelihood(inputs,
168166
tag_indices,
169167
sequence_lengths,
@@ -201,7 +199,6 @@ def crf_log_likelihood(inputs,
201199
return log_likelihood, transition_params
202200

203201

204-
@tf.function
205202
def crf_unary_score(tag_indices, sequence_lengths, inputs):
206203
"""Computes the unary scores of tag sequences.
207204
@@ -236,7 +233,6 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs):
236233
return unary_scores
237234

238235

239-
@tf.function
240236
def crf_binary_score(tag_indices, sequence_lengths, transition_params):
241237
"""Computes the binary scores of tag sequences.
242238
@@ -272,7 +268,6 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params):
272268
return binary_scores
273269

274270

275-
@tf.function
276271
def crf_forward(inputs, state, transition_params, sequence_lengths):
277272
"""Computes the alpha values in a linear-chain CRF.
278273
@@ -364,7 +359,6 @@ def call(self, inputs, state):
364359
return backpointers, new_state
365360

366361

367-
@tf.function
368362
def crf_decode_forward(inputs, state, transition_params, sequence_lengths):
369363
mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1])
370364
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
@@ -373,7 +367,6 @@ def crf_decode_forward(inputs, state, transition_params, sequence_lengths):
373367
return crf_fwd_layer(inputs, state, mask=mask)
374368

375369

376-
@tf.function
377370
def crf_decode_backward(inputs, state):
378371
"""Computes backward decoding in a linear-chain CRF."""
379372
inputs = tf.transpose(inputs, [1, 0, 2])
@@ -387,7 +380,6 @@ def _scan_fn(state, inputs):
387380
return tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2])
388381

389382

390-
@tf.function
391383
def crf_decode(potentials, transition_params, sequence_length):
392384
"""Decode the highest scoring sequence of tags in TensorFlow.
393385

0 commit comments

Comments
 (0)