Skip to content

Commit f448edf

Browse files
authored
Merge pull request #2610 from dzhwinter/go_optimizer
Go optimizer: integrate Go with optimizer library
2 parents f3c9789 + 85c4352 commit f448edf

File tree

15 files changed

+183
-300
lines changed

15 files changed

+183
-300
lines changed

go/pserver/cclient/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
12
go_library(paddle_pserver_cclient STATIC)
2-
3-
add_subdirectory(test)
3+
if(WITH_TESTING)
4+
add_subdirectory(test)
5+
endif()
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
1-
2-
cc_binary(main SRCS main.c DEPS paddle_pserver_cclient)
31
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient)
2+
add_style_check_target(test_cclient test_cclient.c)

go/pserver/cclient/test/main.c

Lines changed: 0 additions & 93 deletions
This file was deleted.

go/pserver/cclient/test/test_cclient.c

Lines changed: 74 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -3,113 +3,101 @@
33

44
#include "libpaddle_pserver_cclient.h"
55

6-
typedef float real;
7-
8-
void fail() {
9-
// TODO(helin): fix: gtest using cmake is not working, using this
10-
// hacky way for now.
11-
printf("test failed.\n");
6+
// TODO(helin): Fix: gtest using cmake is not working, using this
7+
// hacky way for now.
8+
#define fail() \
9+
fprintf(stderr, "info: %s:%d: ", __FILE__, __LINE__); \
1210
exit(-1);
11+
12+
void sendGrads(paddle_pserver_client c) {
13+
unsigned char grad_a[2000] = {2};
14+
unsigned char grad_b[3000] = {3};
15+
paddle_gradient grad1 = {
16+
"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, grad_a, 2000};
17+
paddle_gradient grad2 = {
18+
"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, grad_b, 3000};
19+
paddle_gradient *grads[2] = {&grad1, &grad2};
20+
if (paddle_send_grads(c, grads, 2)) {
21+
fail();
22+
}
1323
}
1424

15-
void print_parameter(paddle_gradient* param) {
16-
if (param == NULL) {
17-
printf("param is NULL!!\n");
18-
} else {
19-
printf("==== parameter ====\n");
20-
printf("name: %s\n", param->name);
21-
printf("content_len: %d\n", param->content_len);
22-
printf("content_type: %d\n", param->element_type);
23-
int i;
24-
for (i = 0; i < param->content_len / (int)sizeof(real); ++i) {
25-
printf("%f ", ((float*)param->content)[i]);
26-
}
27-
printf("\n\n");
25+
void getParams(paddle_pserver_client c) {
26+
paddle_parameter param_a;
27+
paddle_parameter param_b;
28+
char name_a[] = "param_a";
29+
char name_b[] = "param_b";
30+
// Must pre-allocate the prameter content before calling paddle_get_params.
31+
unsigned char content_a[2000] = {};
32+
unsigned char content_b[3000] = {};
33+
param_a.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
34+
param_a.name = name_a;
35+
param_a.content = content_a;
36+
param_a.content_len = 2000;
37+
param_b.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
38+
param_b.name = name_b;
39+
param_b.content = content_b;
40+
param_b.content_len = 3000;
41+
42+
paddle_parameter *params[2] = {&param_a, &param_b};
43+
if (paddle_get_params(c, params, 2)) {
44+
fail();
2845
}
2946
}
3047

3148
int main() {
3249
char addr[] = "localhost:3000";
3350
paddle_pserver_client c = paddle_new_pserver_client(addr, 1);
34-
35-
char* names[] = {"param_a", "param_b"};
36-
51+
char *config_proto;
52+
size_t config_proto_len = 0;
53+
ssize_t nread;
54+
FILE *fp = fopen("testdata/optimizer.pb.txt", "r");
55+
if (!fp) {
56+
fail();
57+
}
58+
while ((nread = getline(&config_proto, &config_proto_len, fp)) != -1) {
59+
printf("%s", config_proto);
60+
}
61+
fclose(fp);
3762
retry:
38-
printf("init parameter to pserver:\n");
39-
40-
real param_content1[] = {0.1, 0.2, 0.3};
41-
real param_content2[] = {0.4, 0.5, 0.6};
42-
paddle_parameter** params =
43-
(paddle_parameter**)malloc(sizeof(paddle_parameter*) * 2);
44-
params[0] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
45-
params[0]->name = names[0];
46-
params[0]->content = (unsigned char*)param_content1;
47-
params[0]->content_len = 3 * sizeof(real);
48-
params[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
49-
50-
params[1] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
51-
params[1]->name = names[1];
52-
params[1]->content = (unsigned char*)param_content2;
53-
params[1]->content_len = 3 * sizeof(real);
54-
params[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;
55-
5663
if (paddle_begin_init_params(c)) {
57-
if (paddle_init_param(c, *params[0], NULL, 0) != 0) {
64+
paddle_parameter param;
65+
char name_a[] = "param_a";
66+
char name_b[] = "param_b";
67+
unsigned char content_a[2000] = {1};
68+
unsigned char content_b[3000] = {0};
69+
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
70+
param.name = name_a;
71+
param.content = content_a;
72+
param.content_len = 2000;
73+
int error =
74+
paddle_init_param(c, param, (void *)config_proto, config_proto_len);
75+
if (error != 0) {
5876
goto retry;
5977
}
60-
if (paddle_init_param(c, *params[1], NULL, 0) != 0) {
78+
79+
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
80+
param.name = name_b;
81+
param.content = content_b;
82+
param.content_len = 3000;
83+
error = paddle_init_param(c, param, (void *)config_proto, config_proto_len);
84+
if (error != 0) {
6185
goto retry;
6286
}
63-
if (paddle_finish_init_params(c) != 0) {
87+
88+
error = paddle_finish_init_params(c);
89+
if (error != 0) {
6490
goto retry;
6591
}
66-
} else {
67-
fail();
68-
}
69-
70-
printf("get inited parameters from pserver:\n");
71-
// get parameters again by reusing the allocated parameter buffers.
72-
if (paddle_get_params(c, params, 2) != 0) {
73-
fail();
74-
}
75-
print_parameter(params[0]);
76-
print_parameter(params[1]);
77-
78-
printf("send gradient to pserver:\n");
79-
real gradient_content1[] = {0.01, 0.02, 0.03};
80-
real gradinet_content2[] = {0.04, 0.05, 0.06};
81-
82-
paddle_gradient** grads =
83-
(paddle_gradient**)malloc(sizeof(paddle_gradient*) * 2);
84-
grads[0] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
85-
grads[0]->name = names[0];
86-
grads[0]->content = (unsigned char*)gradient_content1;
87-
grads[0]->content_len = 3 * sizeof(real);
88-
grads[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
89-
90-
grads[1] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
91-
grads[1]->name = names[1];
92-
grads[1]->content = (unsigned char*)gradinet_content2;
93-
grads[1]->content_len = 3 * sizeof(real);
94-
grads[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;
95-
96-
printf("print gradient sent to pserver:\n");
97-
print_parameter(grads[0]);
98-
print_parameter(grads[1]);
99-
100-
if (paddle_send_grads(c, grads, 2) != 0) {
101-
fail();
10292
}
10393

104-
printf("get updated parameters from pserver:\n");
105-
// get parameters again by reusing the allocated parameter buffers.
106-
if (paddle_get_params(c, params, 2) != 0) {
107-
fail();
94+
int i;
95+
for (i = 0; i < 100; i++) {
96+
sendGrads(c);
97+
getParams(c);
10898
}
109-
print_parameter(params[0]);
110-
print_parameter(params[1]);
11199

112-
if (paddle_save_model(c, "/tmp/") != 0) {
100+
if (paddle_save_model(c, "/tmp/")) {
113101
fail();
114102
}
115103

go/pserver/cclient/test/test_train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def main():
2222
# create optimizer
2323
optimizer = paddle.optimizer.Momentum(momentum=0)
2424

25+
#TODO(zhihong) : replace optimizer with new OptimizerConfig
26+
2527
trainer = paddle.trainer.SGD(cost=cost,
2628
parameters=parameters,
2729
update_equation=optimizer,
51 Bytes
Binary file not shown.

go/pserver/client_test.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package pserver_test
22

33
import (
4+
"io/ioutil"
45
"net"
56
"net/http"
67
"net/rpc"
@@ -74,18 +75,22 @@ func TestClientFull(t *testing.T) {
7475
}
7576

7677
const numParameter = 100
78+
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
79+
if err != nil {
80+
t.Fatalf("read optimizer proto failed")
81+
}
7782
for i := 0; i < numParameter; i++ {
7883
var p pserver.Parameter
7984
p.Name = "p_" + strconv.Itoa(i)
8085
p.ElementType = pserver.Float32
8186
p.Content = make([]byte, (i+1)*100)
82-
err := c.InitParam(pserver.ParameterWithConfig{Param: p})
87+
err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
8388
if err != nil {
8489
t.Fatal(err)
8590
}
8691
}
8792

88-
err := c.FinishInitParams()
93+
err = c.FinishInitParams()
8994
if err != nil {
9095
t.Fatal(err)
9196
}

go/pserver/optimizer.c

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)