-
Notifications
You must be signed in to change notification settings - Fork 577
RFC: Keras categorical inputs #188
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,357 @@ | ||
# Keras categorical inputs | ||
|
||
| Status | Proposed | | ||
:-------------- |:---------------------------------------------------- | | ||
| **Author(s)** | Zhenyu Tan ([email protected]), Francois Chollet ([email protected])| | ||
| **Sponsor** | Karmel Allison ([email protected]), Martin Wicke ([email protected]) | | ||
| **Updated** | 2019-12-12 | | ||
|
||
## Objective | ||
|
||
This document proposes 4 new preprocessing Keras layers (`CategoryLookup`, `CategoryCrossing`, `CategoryEncoding`, `CategoryHashing`), and 1 additional op (`to_sparse`) to allow users to: | ||
* Perform feature engineering for categorical inputs | ||
* Replace feature columns and `tf.keras.layers.DenseFeatures` with proposed layers | ||
* Introduce sparse inputs that work with Keras linear models and other layers that support sparsity | ||
|
||
Other proposed layers for replacement of feature columns such as `tf.feature_column.bucketized_column` and `tf.feature_column.numeric_column` has been discussed [here](https://github.com/keras-team/governance/blob/master/rfcs/20190502-preprocessing-layers.md) and are not the focus of this document. | ||
|
||
## Example Workflows | ||
|
||
Two example workflows are presented below. These workflows can be found at this [colab](https://colab.sandbox.google.com/drive/1cEJhSYLcc2MKH7itwcDvue4PfvrLN-OR#scrollTo=22sa0D19kxXY). | ||
|
||
### Workflow 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it make sense to provide an example workflow where you have to get the vocabulary from e.g. a csv file using tf.data and in particular, tf.data.experimental.unique and tf.data.experimental.get_single_element to read out the tensor? @jsimsa wdyt? This is gonna be very common, i think, in real use cases. |
||
|
||
The first example gives an equivalent code snippet to canned `LinearEstimator` [tutorial](https://www.tensorflow.org/tutorials/estimator/linear) on the Titanic dataset: | ||
|
||
```python | ||
CATEGORICAL_COLUMNS = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck', 'embark_town', 'alone'] | ||
NUMERICAL_COLUMNS = ['age', 'fare'] | ||
# input list to create functional model. | ||
model_inputs = [] | ||
# input list to feed linear model. | ||
linear_inputs = [] | ||
for feature_name in CATEGORICAL_COLUMNS: | ||
feature_input = tf.keras.Input(shape=(1,), dtype=tf.string, name=feature_name, sparse=True) | ||
vocab_list = sorted(dftrain[feature_name].unique()) | ||
# Map string values to indices | ||
x = tf.keras.layers.CategoryLookup(vocabulary=vocab_list, name=feature_name)(feature_input) | ||
x = tf.keras.layers.CategoryEncoding(num_categories=len(vocab_list))(x) | ||
linear_inputs.append(x) | ||
model_inputs.append(feature_input) | ||
|
||
for feature_name in NUMERICAL_COLUMNS: | ||
feature_input = tf.keras.Input(shape=(1,), name=feature_name) | ||
linear_inputs.append(feature_input) | ||
model_inputs.append(feature_input) | ||
|
||
linear_model = tf.keras.experimental.LinearModel(units=1) | ||
linear_logits = linear_model(linear_inputs) | ||
model = tf.keras.Model(model_inputs, linear_logits) | ||
|
||
model.compile('sgd', loss=tf.keras.losses.BinaryCrossEntropy(from_logits=True), metrics=['accuracy']) | ||
|
||
dftrain = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that this is referenced before assignment? Does this code run? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix applied. |
||
y_train = dftrain.pop('survived') | ||
|
||
dataset = tf.data.Dataset.from_tensor_slices(( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to have a single input node and use some layer to decompose the input to these single column nodes? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use tf.split for that? |
||
(tf.to_sparse(dftrain.sex, "Unknown"), tf.to_sparse(dftrain.n_siblings_spouses, -1), | ||
tf.to_sparse(dftrain.parch, -1), tf.to_sparse(dftrain['class'], "Unknown"), tf.to_sparse(dftrain.deck, "Unknown"), | ||
tf.expand_dims(dftrain.age, axis=1), tf.expand_dims(dftrain.fare, axis=1)), | ||
y_train)).batch(bach_size).repeat(n_epochs) | ||
|
||
model.fit(dataset) | ||
``` | ||
|
||
### Workflow 2 | ||
|
||
The second example gives an instruction on how to transition from categorical feature columns to the proposed layers. Note that one difference for vocab categorical column is that, instead of providing a pair of mutually exclusive `default_value` and `num_oov_buckets` where `default_value` represents the value to map input to given out-of-vocab value, and `num_oov_buckets` represents value range of [len(vocab), len(vocab)+num_oov_buckets) to map input to from a hashing function given out-of-vocab value. In practice, we believe out-of-vocab values should be mapped to the head, i.e., [0, num_oov_tokens), and in-vocab values should be mapped to [num_oov_tokens, num_oov_tokens+len(vocab)). | ||
|
||
1. Categorical vocab list column | ||
|
||
Original: | ||
```python | ||
fc = tf.feature_column.categorical_feature_column_with_vocabulary_list( | ||
key, vocabulary_list, dtype, default_value, num_oov_buckets) | ||
``` | ||
Proposed: | ||
```python | ||
x = tf.keras.Input(shape=(1,), name=key, dtype=dtype) | ||
layer = tf.keras.layers.CategoryLookup( | ||
vocabulary=vocabulary_list, num_oov_tokens=num_oov_buckets) | ||
out = layer(x) | ||
``` | ||
|
||
2. categorical vocab file column | ||
|
||
Original: | ||
```python | ||
fc = tf.feature_column.categorical_column_with_vocab_file( | ||
key, vocabulary_file, vocabulary_size, dtype, | ||
default_value, num_oov_buckets) | ||
``` | ||
Proposed: | ||
```python | ||
x = tf.keras.Input(shape=(1,), name=key, dtype=dtype) | ||
layer = tf.keras.layers.CategoryLookup( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah I see, so here you allow the user to provide a vocabulary file directly, so the tf.data example may not be necessary. May still be useful if users have vocab that needs to be munged a bit before reading directly. But less important. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this layer is more of a complimentary to that, i.e., tf.data can parse records and generate vocab file, of read vocab file and do other processing and still return string tensors. This layer is taken from that and convert things to indices before it gets to embedding. |
||
vocabulary=vocabulary_file, num_oov_tokens=num_oov_buckets) | ||
out = layer(x) | ||
``` | ||
Note: `vocabulary_size` is only valid if `adapt` is called. Otherwise if user desires to lookup for the first K vocabularies in vocab file, then shrink the vocab file by only having the first K lines. | ||
|
||
3. categorical hash column | ||
|
||
Original: | ||
```python | ||
fc = tf.feature_column.categorical_column_with_hash_bucket( | ||
key, hash_bucket_size, dtype) | ||
``` | ||
Proposed: | ||
```python | ||
x = tf.keras.Input(shape=(1,), name=key, dtype=dtype) | ||
layer = tf.keras.layers.CategoryHashing(num_bins=hash_bucket_size) | ||
out = layer(x) | ||
``` | ||
|
||
4. categorical identity column | ||
|
||
Original: | ||
```python | ||
fc = tf.feature_column.categorical_column_with_identity( | ||
key, num_buckets, default_value) | ||
``` | ||
Proposed: | ||
```python | ||
x = tf.keras.Input(shape=(1,), name=key, dtype=dtype) | ||
layer = tf.keras.layers.Lambda(lambda x: tf.where(tf.logical_or(x < 0, x > num_buckets), tf.fill(dims=tf.shape(x), value=default_value), x)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we allowed to specify the hash function, this could also be folded into the CategoryHashing with an IdentityHash. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CategoryHashing does not check oov values, so if we have that then it's complicating the signature. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do you not use a layer to the lamda layer? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about adding a layer to do the work of
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The output of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that's why we proposed a tf.sparse.from_dense(value, ignore_value), and our philosophy here is tensor type in / tensor type out, i.e., you can specifically convert to sparse tensors before it gets to here, which is pretty useful in TF Transform you all preprocessing things are done in memory efficient way. |
||
out = layer(x) | ||
``` | ||
|
||
5. cross column | ||
|
||
Original: | ||
```python | ||
fc_1 = tf.feature_column.categorical_column_with_vocabulary_list(key_1, vocabulary_list, | ||
dtype, default_value, num_oov_buckets) | ||
fc_2 = tf.feature_column.categorical_column_with_hash_bucket(key_2, hash_bucket_size, | ||
dtype) | ||
fc = tf.feature_column.crossed_column([fc_1, fc_2], hash_bucket_size, hash_key) | ||
``` | ||
Proposed: | ||
```python | ||
x1 = tf.keras.Input(shape=(1,), name=key_1, dtype=dtype) | ||
x2 = tf.keras.Input(shape=(1,), name=key_2, dtype=dtype) | ||
layer1 = tf.keras.layers.CategoryLookup( | ||
vocabulary=vocabulary_list, | ||
num_oov_tokens=num_oov_buckets) | ||
x1 = layer1(x1) | ||
layer2 = tf.keras.layers.CategoryHashing( | ||
num_bins=hash_bucket_size) | ||
x2 = layer2(x2) | ||
layer = tf.keras.layers.CategoryCrossing(num_bins=hash_bucket_size) | ||
out = layer([x1, x2]) | ||
``` | ||
|
||
6. weighted categorical column | ||
|
||
Original: | ||
```python | ||
fc = tf.feature_column.categorical_column_with_vocab_list(key, vocabulary_list, | ||
dtype, default_value, num_oov_buckets) | ||
weight_fc = tf.feature_column.weighted_categorical_column(fc, weight_feature_key, | ||
dtype=weight_dtype) | ||
linear_model = tf.estimator.LinearClassifier(units, feature_columns=[weight_fc]) | ||
``` | ||
Proposed: | ||
```python | ||
x1 = tf.keras.Input(shape=(1,), name=key, dtype=dtype) | ||
x2 = tf.keras.Input(shape=(1,), name=weight_feature_key, dtype=weight_dtype) | ||
layer = tf.keras.layers.CategoryLookup( | ||
vocabulary=vocabulary_list, | ||
num_oov_tokens=num_oov_buckets) | ||
x1 = layer(x1) | ||
x = tf.keras.layers.CategoryEncoding(num_categories=len(vocabulary_list)+num_oov_buckets)([x1, x2]) | ||
linear_model = tf.keras.premade.LinearModel(units) | ||
linear_logits = linear_model(x) | ||
``` | ||
|
||
## Pain Points | ||
|
||
Specifically, by introducing the 4 layers, we aim to address these pain points: | ||
* Users have to define both feature columns and Keras Inputs for the model, resulting in code duplication and deviation from DRY (Do not repeat yourself) principle. See this [Github issue](https://github.com/tensorflow/tensorflow/issues/27416). | ||
* Users with large dimension categorical inputs will incur large memory footprint and computation cost, if wrapped with indicator column through `tf.keras.layers.DenseFeatures`. | ||
* Currently there is no way to correctly feed Keras linear model or dense layer with multivalent categorical inputs or weighted categorical inputs. | ||
|
||
## Design Proposal | ||
We propose a CategoryLookup layer to replace `tf.feature_column.categorical_column_with_vocabulary_list` and `tf.feature_column.categorical_column_with_vocabulary_file`, a `CategoryHashing` layer to replace `tf.feature_column.categorical_column_with_hash_bucket`, a `CategoryCrossing` layer to replace `tf.feature_column.crossed_column`, and another `CategoryEncoding` layer to convert the sparse input to the format required by linear models. | ||
|
||
```python | ||
`tf.keras.layers.CategoryLookup` | ||
CategoryLookup(PreprocessingLayer): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please don't forget to implement the correct There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the reminder! |
||
"""This layer transforms categorical inputs to index space. | ||
If input is dense/sparse, then output is dense/sparse.""" | ||
|
||
def __init__(self, max_tokens=None, num_oov_tokens=1, vocabulary=None, | ||
name=None, **kwargs): | ||
"""Constructs a CategoryLookup layer. | ||
|
||
Args: | ||
max_tokens: The maximum size of the vocabulary for this layer. If None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Before executing the training loop, we will call preprocessLayers.adapt to calculate the statistical value. If the dataset is very huge, how does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This really depends on what the execution engine is, and not part of the layer's responsibility, but instead should be responsibility of ProcessingStage. |
||
there is no cap on the size of the vocabulary. This is used when `adapt` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO, the statistical process will be executed in a pre-process stage before model training just as the stage where TF-Transform stands in the End-to-End pipeline. How do we pass the statistical result from preprocess stage to the model training stage using keras preprocess input layer? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right, the |
||
is called. | ||
num_oov_tokens: Non-negative integer. The number of out-of-vocab tokens. | ||
All out-of-vocab inputs will be assigned IDs in the range of | ||
[0, num_oov_tokens) based on a hash. When | ||
`vocabulary` is None, it will convert inputs in [0, num_oov_tokens) | ||
vocabulary: the vocabulary to lookup the input. If it is a file, it represents the | ||
source vocab file; If it is a list/tuple, it represents the source vocab | ||
list; If it is None, the vocabulary can later be set. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is the format of the file? how do you set the vocabulary later? what is the expected use of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My point is that this should be documented. Stating the the vocabulary can be set later without showing how is not useful. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
name: Name to give to the layer. | ||
**kwargs: Keyword arguments to construct a layer. | ||
|
||
Input: a string or int tensor of shape `[batch_size, d1, ..., dm]` | ||
Output: an int tensor of shape `[batch_size, d1, ..., dm]` | ||
|
||
Example: | ||
|
||
If one input sample is `["a", "c", "d", "a", "x"]` and the vocabulary is ["a", "b", "c", "d"], | ||
and a single OOV token is used (`num_oov_tokens=1`), then the corresponding output sample is | ||
`[1, 3, 4, 1, 0]`. 0 stands for an OOV token. | ||
""" | ||
pass | ||
|
||
`tf.keras.layers.CategoryCrossing` | ||
CategoryCrossing(PreprocessingLayer): | ||
"""This layer transforms multiple categorical inputs to categorical outputs | ||
by Cartesian product. and hash the output if necessary. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: remove extra There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
If any input is sparse, then output is sparse, otherwise dense.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OOC, why is the wording here different than in the other API endpoints (it seems that the intended behavior is the same?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question. This is the only layer that can accept multiple inputs. Other API only accept a single Tensor/SparseTensor. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it. Maybe you should say, "If any of the inputs is sparse, then all outputs will be sparse. Otherwise, all outputs will be dense." There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah that's better. Done. |
||
|
||
def __init__(self, depth=None, num_bins=None, name=None, **kwargs): | ||
"""Constructs a CategoryCrossing layer. | ||
Args: | ||
depth: depth of input crossing. By default None, all inputs are crossed | ||
into one output. It can be an int or tuple/list of ints, where inputs are | ||
combined into all combinations of output with degree of `depth`. For example, | ||
with inputs `a`, `b` and `c`, `depth=2` means the output will be [ab;ac;bc] | ||
Comment on lines
+239
to
+240
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the example should be moved to the "Example" section below There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Both Example for each layer description, and a code snippet below. |
||
num_bins: Number of hash bins. By default None, no hashing is performed. | ||
name: Name to give to the layer. | ||
**kwargs: Keyword arguments to construct a layer. | ||
|
||
Input: a list of int tensors of shape `[batch_size, d1, ..., dm]` | ||
Output: a single int tensor of shape `[batch_size, d1, ..., dm]` | ||
|
||
Example: | ||
|
||
If the layer receives two inputs, `a=[[1, 2]]` and `b=[[1, 3]]`, | ||
and if depth is 2, then | ||
the output will be a single integer tensor `[[i, j, k, l]]`, where: | ||
i is the index of the category "a1=1 and b1=1" | ||
j is the index of the category "a1=1 and b2=3" | ||
k is the index of the category "a2=2 and b1=1" | ||
l is the index of the category "a2=2 and b2=3" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand this example. What is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah it is confusing. Updated. |
||
""" | ||
pass | ||
|
||
`tf.keras.layers.CategoryEncoding` | ||
CategoryEncoding(PreprocessingLayer): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add example for this layer. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
"""This layer transforms categorical inputs from index space to category space. | ||
If input is dense/sparse, then output is dense/sparse.""" | ||
|
||
def __init__(self, num_categories, mode="sum", axis=-1, name=None, **kwargs): | ||
"""Constructs a CategoryEncoding layer. | ||
Args: | ||
num_categories: Number of elements in the vocabulary. | ||
mode: how to reduce a categorical input if multivalent, can be one of "sum", | ||
"mean", "binary", "tfidf". It can also be None if this is not a multivalent input, | ||
and simply needs to convert input from index space to category space. "tfidf" is only | ||
valid when adapt is called on this layer. | ||
axis: the axis to reduce, by default will be the last axis, specially true | ||
for sequential feature columns. | ||
name: Name to give to the layer. | ||
**kwargs: Keyword arguments to construct a layer. | ||
|
||
Input: a int tensor of shape `[batch_size, d1, ..., dm-1, dm]` | ||
Output: a float tensor of shape `[batch_size, d1, ..., dm-1, num_categories]` | ||
""" | ||
pass | ||
|
||
`tf.keras.layers.CategoryHashing` | ||
CategoryHashing(PreprocessingLayer): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add example for this layer. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
"""This layer transforms categorical inputs to hashed output. | ||
If input is dense/sparse, then output is dense/sparse.""" | ||
def __init__(self, num_bins, name=None, **kwargs): | ||
"""Constructs a CategoryHashing layer. | ||
|
||
Args: | ||
num_bins: Number of hash bins. | ||
name: Name to give to the layer. | ||
**kwargs: Keyword arguments to construct a layer. | ||
|
||
Input: a int tensor of shape `[batch_size, d1, ..., dm]` | ||
Output: a int tensor of shape `[batch_size, d1, ..., dm]` | ||
""" | ||
pass | ||
|
||
``` | ||
|
||
We also propose a `to_sparse` op to convert dense tensors to sparse tensors given user specified ignore values. This op can be used in both `tf.data` or [TF Transform](https://www.tensorflow.org/tfx/transform/get_started). In previous feature column world, "" is ignored for dense string input and -1 is ignored for dense int input. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would prefer if the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we don't need the functionality of WDYT? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I realized, we already have from_dense, so perhaps you should just extend it with the option to set the element to be ignore? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah good point. I wasn't aware of this op. We should just extend it. Done. |
||
|
||
```python | ||
`tf.to_sparse` | ||
def to_sparse(input, ignore_value): | ||
"""Convert dense/sparse tensor to sparse while dropping user specified values. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the benefit of calling this API with a SparseTensor input? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. to allow users to filter specified values, e.g., if the original input is already sparse: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This filtering can be built out of existing operations. You can call, tf.where on There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense. Let's just extend the |
||
|
||
Args: | ||
input: A `Tensor` or `SparseTensor`. | ||
ignore_value: The value to be dropped from input. | ||
""" | ||
pass | ||
``` | ||
|
||
## Code Snippets | ||
|
||
Below is a more detailed illustration of how each layer works. If there is a vocabulary list of countries: | ||
```python | ||
vocabulary_list = ["Italy", "France", "England", "Austria", "Germany"] | ||
inp = np.asarray([["Italy", "Italy"], ["Germany", ""]]) | ||
sp_inp = tf.to_sparse(inp, "") | ||
cat_layer = tf.keras.layers.CategoryLookup(vocabulary=vocabulary_list) | ||
sp_out = cat_layer(sp_inp) | ||
``` | ||
|
||
The categorical layer will first convert the input to: | ||
```python | ||
sp_out.indices = <tf.Tensor: id=8, shape=(3, 2), dtype=int64, numpy= | ||
array([[0, 0], [0, 1] [1, 0]])> | ||
sp_out.values = <tf.Tensor: id=28, shape=(3,), dtype=int64, | ||
numpy=array([0, 0, 4])> | ||
``` | ||
|
||
The `CategoryEncoding` layer will then convert the input from index space to category space, e.g., from a sparse tensor with indices shape as [batch_size, n_columns] and values in the range of [0, n_categories) to a sparse tensor with indices shape as [batch_size, n_categories] and values as the frequency of each value that occured in the example: | ||
```python | ||
encoding_layer = CategoryEncoding(num_categories=len(vocabulary_list)) | ||
sp_encoded_out = encoding_layer(sp_out) | ||
sp_encoded_out.indices = <tf.Tensor: id=8, shape=(2, 2), dtype=int64, numpy= | ||
array([[0, 0], [1, 4]])> | ||
sp_encoded_out.values = <tf.Tensor: id=28, shape=(3,), dtype=int64, | ||
numpy=array([2., 1.])> | ||
``` | ||
A weight input can also be passed into the layer if different categories/examples should be treated differently. | ||
|
||
If this input needs to be crossed with another categorical input, say a vocabulary list of days, then use `CategoryCrossing` which works in the same way as `tf.feature_column.crossed_column` without setting `depth`: | ||
```python | ||
days = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"] | ||
inp_days = tf.to_sparse(np.asarray([["Sunday"], [""]]), ignore_value="") | ||
layer_days = CategoryLookup(vocabulary=days) | ||
sp_out_2 = layer_days(inp_days) | ||
|
||
sp_out_2.indices = <tf.Tensor: id=161, shape=(1, 2), dtype=int64, numpy=array([[0, 0]])> | ||
sp_out_2.values = <tf.Tensor: id=181, shape=(1,), dtype=int64, numpy=array([6])> | ||
|
||
cross_layer = CategoryCrossing(num_bins=5) | ||
# Use the output from CategoryLookup (sp_out), not CategoryEncoding (sp_combined_out) | ||
crossed_out = cross_layer([sp_out, sp_out_2]) | ||
|
||
cross_out.indices = <tf.Tensor: id=186, shape=(2, 2), dtype=int64, numpy= | ||
array([[0, 0], [0, 1]])> | ||
cross_out.values = <tf.Tensor: id=187, shape=(2,), dtype=int64, numpy=array([3, 3])> | ||
``` |
Uh oh!
There was an error while loading. Please reload this page.