|
1 | 1 | import collections
|
| 2 | +import itertools |
2 | 3 | import math
|
3 | 4 |
|
| 5 | +import tensorflow as tf |
4 | 6 | from tensorflow.python.feature_column import feature_column as fc_old
|
5 | 7 | from tensorflow.python.feature_column import feature_column_v2 as fc_lib
|
6 | 8 | from tensorflow.python.framework import tensor_shape
|
@@ -185,3 +187,135 @@ def reset(self):
|
@property
def embedding_and_ids(self):
    """Return the `embedding_and_ids` attribute of the embedding delegate.

    This is a read-only pass-through to
    `self._embedding_delegate.embedding_and_ids`; nothing is computed here.
    """
    return self._embedding_delegate.embedding_and_ids
|
| 190 | + |
| 191 | + |
def concat_column(categorical_columns):
    """Build a `ConcatColumn` from a list of categorical columns.

    Args:
        categorical_columns: A non-empty `list` whose items are all
            instances of `fc_lib.CategoricalColumn`.

    Returns:
        A `ConcatColumn` wrapping the given columns (stored as a tuple, in
        the order they were given).

    Raises:
        ValueError: If `categorical_columns` is not a list, is empty, or
            contains an item that is not a `fc_lib.CategoricalColumn`.
    """
    if not isinstance(categorical_columns, list):
        raise ValueError("categorical_columns should be a list")
    if not categorical_columns:
        raise ValueError("categorical_columns shouldn't be empty")

    # Collect the offending items up front; only the first one is reported.
    rejected = [
        item
        for item in categorical_columns
        if not isinstance(item, fc_lib.CategoricalColumn)
    ]
    if rejected:
        message = (
            "Items of categorical_columns should be CategoricalColumn."
            " Given:{}"
        )
        raise ValueError(message.format(rejected[0]))

    columns = tuple(categorical_columns)
    return ConcatColumn(categorical_columns=columns)
| 207 | + |
| 208 | + |
class ConcatColumn(
    fc_lib.CategoricalColumn,
    fc_old._CategoricalColumn,
    collections.namedtuple("ConcatColumn", ("categorical_columns",)),
):
    """Categorical column that concatenates the ids of several columns.

    The sparse id tensor of every wrapped column is shifted by the summed
    `num_buckets` of the columns that precede it, and the shifted tensors
    are concatenated along the last axis. Assuming each wrapped column
    emits ids in `[0, num_buckets)`, the shifted id ranges do not overlap.

    Attributes:
        categorical_columns: The wrapped columns (the only namedtuple
            field).
        accumulated_offsets: List holding, for each wrapped column, the
            value added to its ids.
        total_num_buckets: Sum of `num_buckets` over all wrapped columns.

    NOTE(review): the v1 hooks `_transform_feature` / `_get_sparse_tensors`
    are not overridden in this class body -- confirm whether v1 code paths
    need them.
    """

    def __init__(self, *args, **kwargs):
        """Precompute the per-column id offsets and the total bucket count.

        The namedtuple `__new__` has already stored `categorical_columns`
        by the time this runs, so the arguments are only accepted (both
        positionally and by keyword, mirroring `__new__`) and not used.
        """
        bucket_sizes = [
            column.num_buckets for column in self.categorical_columns
        ]
        # Column i is shifted by the bucket counts of columns 0..i-1, so
        # the first offset is 0 and the last bucket size is never an offset.
        self.accumulated_offsets = list(
            itertools.accumulate([0] + bucket_sizes[:-1])
        )
        self.total_num_buckets = sum(bucket_sizes)

    @property
    def _is_v2_column(self):
        """True only if every wrapped column reports `_is_v2_column`."""
        return all(
            column._is_v2_column for column in self.categorical_columns
        )

    @property
    def name(self):
        """Names of the wrapped columns, sorted and joined with "_C_".

        NOTE(review): because the names are sorted, the result does not
        depend on the order of `categorical_columns`, whereas the id
        offsets do -- confirm this is intended.
        """
        return "_C_".join(
            sorted(column.name for column in self.categorical_columns)
        )

    @property
    def num_buckets(self):
        """Total number of buckets over all wrapped columns."""
        return self.total_num_buckets

    @property
    def _num_buckets(self):
        """Same value as `num_buckets`, under the v1 property name."""
        return self.total_num_buckets

    def transform_feature(self, transformation_cache, state_manager):
        """Return the offset-shifted ids of all columns as one SparseTensor.

        Args:
            transformation_cache: Forwarded to each wrapped column's
                `get_sparse_tensors`.
            state_manager: Forwarded to each wrapped column's
                `get_sparse_tensors`.

        Returns:
            The result of `tf.sparse.concat` along the last axis over the
            per-column id tensors, each cast to int64 and shifted by its
            entry in `accumulated_offsets`.
        """
        id_tensors = [
            column.get_sparse_tensors(
                transformation_cache, state_manager
            ).id_tensor
            for column in self.categorical_columns
        ]

        shifted_tensors = [
            tf.SparseTensor(
                indices=id_tensor.indices,
                # Cast before adding so the shift is always done in int64,
                # even if a wrapped column yields a narrower integer dtype.
                values=tf.add(tf.cast(id_tensor.values, tf.int64), offset),
                dense_shape=id_tensor.dense_shape,
            )
            for id_tensor, offset in zip(
                id_tensors, self.accumulated_offsets
            )
        ]

        return tf.sparse.concat(axis=-1, sp_inputs=shifted_tensors)

    def get_sparse_tensors(self, transformation_cache, state_manager):
        """Return an `IdWeightPair` with the concatenated ids.

        The id tensor is fetched through `transformation_cache.get`; the
        weight tensor is always `None`.
        """
        return fc_lib.CategoricalColumn.IdWeightPair(
            transformation_cache.get(self, state_manager), None
        )

    @property
    def parents(self):
        """The wrapped columns, as a list."""
        return list(self.categorical_columns)

    @property
    def parse_example_spec(self):
        """Merged `parse_example_spec` of all wrapped columns.

        Specs are merged with `dict.update` in column order, so on a
        duplicate key the later column's entry wins.
        """
        config = {}
        for column in self.categorical_columns:
            config.update(column.parse_example_spec)

        return config

    @property
    def _parse_example_spec(self):
        """Same value as `parse_example_spec`, under the v1 name."""
        return self.parse_example_spec

    def get_config(self):
        """Return the namedtuple fields as a dict, with columns serialized.

        `categorical_columns` is replaced by a tuple of the results of
        `serialize_feature_column` for each wrapped column.
        """
        from tensorflow.python.feature_column.serialization import (
            serialize_feature_column,
        )  # pylint: disable=g-import-not-at-top

        config = dict(zip(self._fields, self))
        config["categorical_columns"] = tuple(
            serialize_feature_column(column)
            for column in self.categorical_columns
        )

        return config

    @classmethod
    def from_config(cls, config, custom_objects=None, columns_by_name=None):
        """Rebuild a column from a config dict; see `FeatureColumn`.

        Args:
            config: Dict whose "categorical_columns" entry holds the
                serialized wrapped columns.
            custom_objects: Forwarded to `deserialize_feature_column`.
            columns_by_name: Forwarded to `deserialize_feature_column`.

        Returns:
            A new instance of `cls` wrapping the deserialized columns.
        """
        from tensorflow.python.feature_column.serialization import (
            deserialize_feature_column,
        )  # pylint: disable=g-import-not-at-top

        fc_lib._check_config_keys(config, cls._fields)
        kwargs = fc_lib._standardize_and_copy_config(config)
        kwargs["categorical_columns"] = tuple(
            deserialize_feature_column(c, custom_objects, columns_by_name)
            for c in config["categorical_columns"]
        )

        return cls(**kwargs)
0 commit comments