Skip to content

Fix kappa #1047

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 51 additions & 11 deletions tensorflow_addons/metrics/cohens_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,26 @@ def __init__(self,
num_classes: FloatTensorLike,
name: str = 'cohen_kappa',
weightage: Optional[str] = None,
sparse_labels: bool = False,
regression: bool = False,
dtype: AcceptableDTypes = None,
**kwargs):
"""Creates a `CohenKappa` instance.

Args:
num_classes: Number of unique classes in your dataset.
name: (Optional) String name of the metric instance.
weightage: (Optional) Weighting to be considered for calculating
weightage: (optional) Weighting to be considered for calculating
kappa statistics. A valid value is one of
[None, 'linear', 'quadratic']. Defaults to `None`.
dtype: (Optional) Data type of the metric result.
Defaults to `None`.
[None, 'linear', 'quadratic']. Defaults to `None`
sparse_lables: (bool) Valid only for multi-class scenario.
If True, ground truth labels are expected tp be integers
and not one-hot encoded
regression: (bool) If set, that means the problem is being treated
as a regression problem where you are regressing the predictions.
**Note:** If you are regressing for the values, the the output layer
should contain a single unit.
name: (optional) String name of the metric instance
dtype: (optional) Data type of the metric result. Defaults to `None`

Raises:
ValueError: If the value passed for `weightage` is invalid
Expand All @@ -89,8 +97,18 @@ def __init__(self,
if weightage not in (None, 'linear', 'quadratic'):
raise ValueError("Unknown kappa weighting type.")

if num_classes == 2:
self._update = self._update_binary_class_model
elif num_classes > 2:
self._update = self._update_multi_class_model
else:
raise ValueError("""Number of classes must be
greater than or euqal to two""")

self.weightage = weightage
self.num_classes = num_classes
self.regression = regression
self.sparse_labels = sparse_labels
self.conf_mtx = self.add_weight(
'conf_mtx',
shape=(self.num_classes, self.num_classes),
Expand All @@ -114,22 +132,42 @@ def update_state(self, y_true, y_pred, sample_weight=None):
Returns:
Update op.
"""
return self._update(y_true, y_pred, sample_weight)

def _update_binary_class_model(self, y_true, y_pred, sample_weight=None):
y_true = tf.cast(y_true, dtype=tf.int64)
y_pred = tf.cast(y_pred, dtype=tf.int64)
y_pred = tf.cast(y_pred, dtype=tf.float32)
y_pred = tf.cast(y_pred > 0.5, dtype=tf.int64)
return self._update_confusion_matrix(y_true, y_pred, sample_weight)

def _update_multi_class_model(self, y_true, y_pred, sample_weight=None):
if not self.sparse_labels:
y_true = tf.cast(tf.argmax(y_true, axis=-1), dtype=tf.int64)
else:
y_true = tf.cast(y_true, dtype=tf.int64)

if tf.rank(y_pred) > 1:
if not self.regression:
y_pred = tf.cast(tf.argmax(y_pred, axis=-1), dtype=tf.int64)
else:
y_pred = tf.math.round(tf.math.abs(y_pred))
y_pred = tf.cast(y_pred, dtype=tf.int64)
else:
y_pred = tf.cast(y_pred, dtype=tf.int64)

return self._update_confusion_matrix(y_true, y_pred, sample_weight)

if y_true.shape != y_pred.shape:
raise ValueError(
"Number of samples in `y_true` and `y_pred` are different")
def _update_confusion_matrix(self, y_true, y_pred, sample_weight):
y_true = tf.squeeze(y_true)
y_pred = tf.squeeze(y_pred)

# compute the new values of the confusion matrix
new_conf_mtx = tf.math.confusion_matrix(
labels=y_true,
predictions=y_pred,
num_classes=self.num_classes,
weights=sample_weight,
dtype=tf.float32)

# update the values in the original confusion matrix
return self.conf_mtx.assign_add(new_conf_mtx)

def result(self):
Expand Down Expand Up @@ -179,6 +217,8 @@ def get_config(self):
config = {
"num_classes": self.num_classes,
"weightage": self.weightage,
"sparse_labels": self.sparse_labels,
"regression": self.regression
}
base_config = super().get_config()
return {**base_config, **config}
Expand Down
80 changes: 77 additions & 3 deletions tensorflow_addons/metrics/cohens_kappa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
"""Tests for Cohen's Kappa Metric."""

import numpy as np
import tensorflow as tf
from tensorflow_addons.metrics import CohenKappa
from tensorflow_addons.utils import test_utils
Expand All @@ -34,9 +35,9 @@ def test_config(self):
self.assertEqual(kp_obj.num_classes, 5)

def initialize_vars(self):
kp_obj1 = CohenKappa(num_classes=5)
kp_obj2 = CohenKappa(num_classes=5, weightage="linear")
kp_obj3 = CohenKappa(num_classes=5, weightage="quadratic")
kp_obj1 = CohenKappa(num_classes=5, sparse_labels=True)
kp_obj2 = CohenKappa(num_classes=5, sparse_labels=True, weightage="linear")
kp_obj3 = CohenKappa(num_classes=5, sparse_labels=True, weightage="quadratic")

self.evaluate(tf.compat.v1.variables_initializer(kp_obj1.variables))
self.evaluate(tf.compat.v1.variables_initializer(kp_obj2.variables))
Expand Down Expand Up @@ -147,12 +148,85 @@ def test_large_values(self):
y_true = [1] * 10000 + [0] * 20000 + [1] * 20000
y_pred = [0] * 20000 + [1] * 30000

y_true = tf.convert_to_tensor(y_true)
y_pred = tf.convert_to_tensor(y_pred)

obj = CohenKappa(num_classes=2)
self.evaluate(tf.compat.v1.variables_initializer(obj.variables))

self.evaluate(obj.update_state(y_true, y_pred))
self.assertAllClose(0.166666666, obj.result())

def test_with_sparse_labels(self):
y_true = np.array([4, 4, 3, 4], dtype=np.int32)
y_pred = np.array([4, 4, 1, 2], dtype=np.int32)

obj = CohenKappa(num_classes=5, sparse_labels=True)
self.evaluate(tf.compat.v1.variables_initializer(obj.variables))

self.evaluate(obj.update_state(y_true, y_pred))
self.assertAllClose(0.19999999, obj.result())

def test_with_ohe_labels(self):
y_true = np.array([4, 4, 3, 4], dtype=np.int32)
y_true = tf.keras.utils.to_categorical(y_true, num_classes=5)
y_pred = np.array([4, 4, 1, 2], dtype=np.int32)

obj = CohenKappa(num_classes=5, sparse_labels=False)
self.evaluate(tf.compat.v1.variables_initializer(obj.variables))

self.evaluate(obj.update_state(y_true, y_pred))
self.assertAllClose(0.19999999, obj.result())

def test_keras_binary_reg_model(self):
kp = CohenKappa(num_classes=2)
inputs = tf.keras.layers.Input(shape=(10,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.models.Model(inputs, outputs)
model.compile(optimizer="sgd", loss="mse", metrics=[kp])

x = np.random.rand(1000, 10).astype(np.float32)
y = np.random.randint(2, size=(1000, 1)).astype(np.float32)

model.fit(x, y, epochs=1, verbose=0, batch_size=32)

def test_keras_multiclass_reg_model(self):
kp = CohenKappa(num_classes=5, regression=True, sparse_labels=True)
inputs = tf.keras.layers.Input(shape=(10,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.models.Model(inputs, outputs)
model.compile(optimizer="sgd", loss="mse", metrics=[kp])

x = np.random.rand(1000, 10).astype(np.float32)
y = np.random.randint(5, size=(1000,)).astype(np.float32)

model.fit(x, y, epochs=1, verbose=0, batch_size=32)

def test_keras_binary_clasasification_model(self):
kp = CohenKappa(num_classes=2)
inputs = tf.keras.layers.Input(shape=(10,))
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(inputs)
model = tf.keras.models.Model(inputs, outputs)
model.compile(optimizer="sgd", loss="binary_crossentropy", metrics=[kp])

x = np.random.rand(1000, 10).astype(np.float32)
y = np.random.randint(2, size=(1000, 1)).astype(np.float32)

model.fit(x, y, epochs=1, verbose=0, batch_size=32)

def test_keras_multiclass_classification_model(self):
kp = CohenKappa(num_classes=5)
inputs = tf.keras.layers.Input(shape=(10,))
outputs = tf.keras.layers.Dense(5, activation="softmax")(inputs)
model = tf.keras.models.Model(inputs, outputs)
model.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=[kp])

x = np.random.rand(1000, 10).astype(np.float32)
y = np.random.randint(5, size=(1000,)).astype(np.float32)
y = tf.keras.utils.to_categorical(y, num_classes=5)

model.fit(x, y, epochs=1, verbose=0, batch_size=32)


if __name__ == "__main__":
tf.test.main()