-
Notifications
You must be signed in to change notification settings - Fork 5.9k
optimize cross entropy kernel by using reduce. #4237
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
a3a8a09
141b8db
30bfaab
6735585
201c2bc
000d751
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,7 @@ def setUp(self): | |
| dtype="float32") | ||
| self.inputs = {"X": X, "Label": label} | ||
| self.outputs = {"Y": cross_entropy} | ||
| self.attrs = {'soft_label': 0} | ||
| self.attrs = {"soft_label": 0} | ||
|
|
||
| def test_check_output(self): | ||
| self.check_output() | ||
|
|
@@ -34,24 +34,25 @@ class TestCrossEntropyOp2(OpTest): | |
|
|
||
| def setUp(self): | ||
| self.op_type = "cross_entropy" | ||
| batch_size = 10 | ||
| class_num = 5 | ||
| batch_size = 13 | ||
| class_num = 37 | ||
| X = np.random.uniform(0.1, 1.0, | ||
| [batch_size, class_num]).astype("float32") | ||
| label = np.random.uniform(0.1, 1.0, | ||
| [batch_size, class_num]).astype("float32") | ||
| label /= label.sum(axis=1, keepdims=True) | ||
| cross_entropy = (-label * np.log(X)).sum( | ||
| axis=1, keepdims=True).astype("float32") | ||
| self.inputs = {'X': X, 'Label': label} | ||
| self.outputs = {'Y': cross_entropy} | ||
| self.attrs = {'soft_label': 1} | ||
|
|
||
| self.inputs = {"X": X, "Label": label} | ||
| self.outputs = {"Y": cross_entropy} | ||
| self.attrs = {"soft_label": 1} | ||
|
|
||
| def test_check_output(self): | ||
| self.check_output() | ||
|
|
||
| def test_check_grad(self): | ||
| self.check_grad(['X'], 'Y') | ||
| self.check_grad(["X"], "Y", max_relative_error=0.05) | ||
|
|
||
|
|
||
| class TestCrossEntropyOp3(OpTest): | ||
|
|
@@ -61,8 +62,8 @@ class TestCrossEntropyOp3(OpTest): | |
|
|
||
| def setUp(self): | ||
| self.op_type = "cross_entropy" | ||
| batch_size = 30 | ||
| class_num = 10 | ||
| batch_size = 13 | ||
| class_num = 37 | ||
| X = np.random.uniform(0.1, 1.0, | ||
| [batch_size, class_num]).astype("float32") | ||
| label_index = np.random.randint( | ||
|
|
@@ -74,15 +75,15 @@ def setUp(self): | |
| dtype="float32") | ||
| cross_entropy2 = (-label * np.log(X)).sum( | ||
| axis=1, keepdims=True).astype("float32") | ||
| self.inputs = {'X': X, 'Label': label} | ||
| self.outputs = {'Y': cross_entropy} | ||
| self.attrs = {'soft_label': 1} | ||
| self.inputs = {"X": X, "Label": label} | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://www.python.org/dev/peps/pep-0008/#string-quotes We following
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I know, only personally, I prefer to keep consistent in one file. |
||
| self.outputs = {"Y": cross_entropy} | ||
| self.attrs = {"soft_label": 1} | ||
|
|
||
| def test_check_output(self): | ||
| self.check_output() | ||
|
|
||
| def test_check_grad(self): | ||
| self.check_grad(['X'], 'Y') | ||
| self.check_grad(["X"], "Y", max_relative_error=0.05) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
去掉line 123, 写成:
SoftCrossEntropyKernel<T, 512><<<d, block>>>There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.