-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add image classification unit test using simplified fluid API #10306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
cs2be
merged 2 commits into
PaddlePaddle:develop
from
cs2be:SIMPLIFY_IMAGE_RECOGNITION_TEST
May 1, 2018
Merged
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
145 changes: 145 additions & 0 deletions
145
python/paddle/fluid/tests/book/image_classification/notest_image_classification_resnet.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,145 @@ | ||
| # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import print_function | ||
|
|
||
| import paddle | ||
| import paddle.fluid as fluid | ||
| import numpy | ||
|
|
||
|
|
||
| def resnet_cifar10(input, depth=32): | ||
| def conv_bn_layer(input, | ||
| ch_out, | ||
| filter_size, | ||
| stride, | ||
| padding, | ||
| act='relu', | ||
| bias_attr=False): | ||
| tmp = fluid.layers.conv2d( | ||
| input=input, | ||
| filter_size=filter_size, | ||
| num_filters=ch_out, | ||
| stride=stride, | ||
| padding=padding, | ||
| act=None, | ||
| bias_attr=bias_attr) | ||
| return fluid.layers.batch_norm(input=tmp, act=act) | ||
|
|
||
| def shortcut(input, ch_in, ch_out, stride): | ||
| if ch_in != ch_out: | ||
| return conv_bn_layer(input, ch_out, 1, stride, 0, None) | ||
| else: | ||
| return input | ||
|
|
||
| def basicblock(input, ch_in, ch_out, stride): | ||
| tmp = conv_bn_layer(input, ch_out, 3, stride, 1) | ||
| tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True) | ||
| short = shortcut(input, ch_in, ch_out, stride) | ||
| return fluid.layers.elementwise_add(x=tmp, y=short, act='relu') | ||
|
|
||
| def layer_warp(block_func, input, ch_in, ch_out, count, stride): | ||
| tmp = block_func(input, ch_in, ch_out, stride) | ||
| for i in range(1, count): | ||
| tmp = block_func(tmp, ch_out, ch_out, 1) | ||
| return tmp | ||
|
|
||
| assert (depth - 2) % 6 == 0 | ||
| n = (depth - 2) / 6 | ||
| conv1 = conv_bn_layer( | ||
| input=input, ch_out=16, filter_size=3, stride=1, padding=1) | ||
| res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) | ||
| res2 = layer_warp(basicblock, res1, 16, 32, n, 2) | ||
| res3 = layer_warp(basicblock, res2, 32, 64, n, 2) | ||
| pool = fluid.layers.pool2d( | ||
| input=res3, pool_size=8, pool_type='avg', pool_stride=1) | ||
| return pool | ||
|
|
||
|
|
||
| def inference_network(): | ||
| classdim = 10 | ||
| data_shape = [3, 32, 32] | ||
| images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') | ||
| net = resnet_cifar10(images, 32) | ||
| predict = fluid.layers.fc(input=net, size=classdim, act='softmax') | ||
| return predict | ||
|
|
||
|
|
||
| def train_network(): | ||
| predict = inference_network() | ||
| label = fluid.layers.data(name='label', shape=[1], dtype='int64') | ||
| cost = fluid.layers.cross_entropy(input=predict, label=label) | ||
| avg_cost = fluid.layers.mean(cost) | ||
| accuracy = fluid.layers.accuracy(input=predict, label=label) | ||
| return avg_cost, accuracy | ||
|
|
||
|
|
||
| def train(use_cuda, save_path): | ||
| BATCH_SIZE = 128 | ||
| EPOCH_NUM = 1 | ||
|
|
||
| train_reader = paddle.batch( | ||
| paddle.reader.shuffle( | ||
| paddle.dataset.cifar.train10(), buf_size=128 * 10), | ||
| batch_size=BATCH_SIZE) | ||
|
|
||
| test_reader = paddle.batch( | ||
| paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE) | ||
|
|
||
| def event_handler(event): | ||
| if isinstance(event, fluid.EndIteration): | ||
| if (event.batch_id % 10) == 0: | ||
| avg_cost, accuracy = trainer.test(reader=test_reader) | ||
|
|
||
| print('BatchID {1:04}, Loss {2:2.2}, Acc {3:2.2}'.format( | ||
| event.batch_id + 1, avg_cost, accuracy)) | ||
|
|
||
| if accuracy > 0.01: # Low threshold for speeding up CI | ||
| trainer.params.save(save_path) | ||
| return | ||
|
|
||
| place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() | ||
| trainer = fluid.Trainer( | ||
| train_network, | ||
| optimizer=fluid.optimizer.Adam(learning_rate=0.001), | ||
| place=place, | ||
| event_handler=event_handler) | ||
| trainer.train(train_reader, EPOCH_NUM, event_handler=event_handler) | ||
|
|
||
|
|
||
| def infer(use_cuda, save_path): | ||
| params = fluid.Params(save_path) | ||
| place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() | ||
| inferencer = fluid.Inferencer(inference_network, params, place=place) | ||
|
|
||
| # The input's dimension of conv should be 4-D or 5-D. | ||
| # Use normilized image pixels as input data, which should be in the range | ||
| # [0, 1.0]. | ||
| tensor_img = numpy.random.rand(1, 3, 32, 32).astype("float32") | ||
| results = inferencer.infer({'pixel': tensor_img}) | ||
|
|
||
| print("infer results: ", results) | ||
|
|
||
|
|
||
| def main(use_cuda): | ||
| if use_cuda and not fluid.core.is_compiled_with_cuda(): | ||
| return | ||
| save_path = "image_classification_resnet.inference.model" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use save_path actually make the program easier to read. Nice! |
||
| train(use_cuda, save_path) | ||
| infer(use_cuda, save_path) | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| for use_cuda in (False, True): | ||
| main(use_cuda=use_cuda) | ||
124 changes: 124 additions & 0 deletions
124
python/paddle/fluid/tests/book/image_classification/notest_image_classification_vgg.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import print_function | ||
|
|
||
| import paddle | ||
| import paddle.fluid as fluid | ||
| import numpy | ||
|
|
||
|
|
||
| def vgg16_bn_drop(input): | ||
| def conv_block(input, num_filter, groups, dropouts): | ||
| return fluid.nets.img_conv_group( | ||
| input=input, | ||
| pool_size=2, | ||
| pool_stride=2, | ||
| conv_num_filter=[num_filter] * groups, | ||
| conv_filter_size=3, | ||
| conv_act='relu', | ||
| conv_with_batchnorm=True, | ||
| conv_batchnorm_drop_rate=dropouts, | ||
| pool_type='max') | ||
|
|
||
| conv1 = conv_block(input, 64, 2, [0.3, 0]) | ||
| conv2 = conv_block(conv1, 128, 2, [0.4, 0]) | ||
| conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0]) | ||
| conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0]) | ||
| conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0]) | ||
|
|
||
| drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5) | ||
| fc1 = fluid.layers.fc(input=drop, size=4096, act=None) | ||
| bn = fluid.layers.batch_norm(input=fc1, act='relu') | ||
| drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5) | ||
| fc2 = fluid.layers.fc(input=drop2, size=4096, act=None) | ||
| return fc2 | ||
|
|
||
|
|
||
| def inference_network(): | ||
| classdim = 10 | ||
| data_shape = [3, 32, 32] | ||
| images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') | ||
| net = vgg16_bn_drop(images) | ||
| predict = fluid.layers.fc(input=net, size=classdim, act='softmax') | ||
| return predict | ||
|
|
||
|
|
||
| def train_network(): | ||
| predict = inference_network() | ||
| label = fluid.layers.data(name='label', shape=[1], dtype='int64') | ||
| cost = fluid.layers.cross_entropy(input=predict, label=label) | ||
| avg_cost = fluid.layers.mean(cost) | ||
| accuracy = fluid.layers.accuracy(input=predict, label=label) | ||
| return avg_cost, accuracy | ||
|
|
||
|
|
||
| def train(use_cuda, save_path): | ||
| BATCH_SIZE = 128 | ||
| EPOCH_NUM = 1 | ||
|
|
||
| train_reader = paddle.batch( | ||
| paddle.reader.shuffle( | ||
| paddle.dataset.cifar.train10(), buf_size=128 * 10), | ||
| batch_size=BATCH_SIZE) | ||
|
|
||
| test_reader = paddle.batch( | ||
| paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE) | ||
|
|
||
| def event_handler(event): | ||
| if isinstance(event, fluid.EndIteration): | ||
| if (event.batch_id % 10) == 0: | ||
| avg_cost, accuracy = trainer.test(reader=test_reader) | ||
|
|
||
| print('BatchID {1:04}, Loss {2:2.2}, Acc {3:2.2}'.format( | ||
| event.batch_id + 1, avg_cost, accuracy)) | ||
|
|
||
| if accuracy > 0.01: # Low threshold for speeding up CI | ||
| trainer.params.save(save_path) | ||
| return | ||
|
|
||
| place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() | ||
| trainer = fluid.Trainer( | ||
| train_network, | ||
| optimizer=fluid.optimizer.Adam(learning_rate=0.001), | ||
| place=place, | ||
| event_handler=event_handler) | ||
| trainer.train(train_reader, EPOCH_NUM, event_handler=event_handler) | ||
|
|
||
|
|
||
| def infer(use_cuda, save_path): | ||
| params = fluid.Params(save_path) | ||
| place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() | ||
| inferencer = fluid.Inferencer(inference_network, params, place=place) | ||
|
|
||
| # The input's dimension of conv should be 4-D or 5-D. | ||
| # Use normilized image pixels as input data, which should be in the range | ||
| # [0, 1.0]. | ||
| tensor_img = numpy.random.rand(1, 3, 32, 32).astype("float32") | ||
| results = inferencer.infer({'pixel': tensor_img}) | ||
|
|
||
| print("infer results: ", results) | ||
|
|
||
|
|
||
| def main(use_cuda): | ||
| if use_cuda and not fluid.core.is_compiled_with_cuda(): | ||
| return | ||
| save_path = "image_classification_vgg.inference.model" | ||
| train(use_cuda, save_path) | ||
| infer(use_cuda, save_path) | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| for use_cuda in (False, True): | ||
| main(use_cuda=use_cuda) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Glad that we have batch_id return in the event.