12 changes: 6 additions & 6 deletions doc/design/python_api.md
@@ -22,7 +22,7 @@ Whenever we create a block, we need to set its parent block to the current block
 ```python
 class Program(object):
     def __init__(self):
-        self.proto = core.NewProgram() # a C++ ProgramDesc pointer.
+        self.desc = core.NewProgram() # a C++ ProgramDesc pointer.
         self.blocks = vector<Block>()
         self.blocks.append(Block(self, -1)) # the global block
         self.current_block = 0 # initialized to the global block
@@ -57,7 +57,7 @@ A [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.m
 ```python
 class Block(object):
     def __init__(self, program, parent_idx):
-        self.proto = core.NewBlock(program.proto)
+        self.desc = core.NewBlock(program.desc)
         self.program = program
         self.vars = map<string, Variable>()
         self.ops = vector<Operator>()
@@ -98,11 +98,11 @@ class Operator(object):
                  outputs, # dict<string, Variable>
                  attrs    # dict<string, Any>
                  ):
-        self.proto = core.NewOpDesc(block.proto, type, inputs, outputs, attrs)
-        core.infer_shape(self.proto, inputs, outputs)
+        self.desc = core.NewOpDesc(block.desc, type, inputs, outputs, attrs)
+        core.infer_shape(self.desc, inputs, outputs)
 
     def type(self):
-        return self.proto.type()
+        return self.desc.type()
 ```
 
 `Operator` creates the `OpDesc` message in C++ space, so that it can call the `InferShape` function, which is implemented in C++.
@@ -124,7 +124,7 @@ class Variable(object):
             name = unique_name_generator()
         self.name = name
         self.block = block
-        self.proto = core.NewVarDesc(block.proto, name, shape, lod_level)
+        self.desc = core.NewVarDesc(block.desc, name, shape, lod_level)
         self.writer = None
 ```
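The doc change above is purely a rename (`proto` → `desc`), but it makes the ownership chain explicit: `Program.desc` is the C++ `ProgramDesc`, each `Block.desc` a `BlockDesc`, and variables and operators hang their `VarDesc`/`OpDesc` off the enclosing block. A minimal sketch of how the pieces compose under the `graph.py` API changed later in this diff (hypothetical usage, not code from the PR; `g_program` is the module-level singleton that the new test imports):

```python
from paddle.v2.framework.graph import g_program

block = g_program.current_block()   # the global block, idx 0
# Creating a Python Variable immediately creates a C++ VarDesc
# inside block.desc (a C++ BlockDesc).
x = block.create_var(name="x", shape=[32, 784], dtype="float32", lod_level=0)
assert x.shape == (32, 784)

sub = g_program.create_block()      # appends a new BlockDesc in C++ space
assert sub.parent_idx == block.idx  # assumes append_block wires up the parent
```

The Python wrappers stay deliberately thin: all state lives in the C++ `*Desc` protobuf messages, and Python only holds handles to them.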
8 changes: 8 additions & 0 deletions paddle/framework/var_desc.cc
@@ -32,5 +32,13 @@ std::vector<int64_t> VarDescBind::Shape() const {
 DataType VarDescBind::GetDataType() const {
   return desc_.lod_tensor().data_type();
 }
+
+void VarDescBind::SetLoDLevel(int32_t lod_level) {
+  desc_.mutable_lod_tensor()->set_lod_level(lod_level);
+}
+
+int32_t VarDescBind::GetLodLevel() const {
+  return desc_.lod_tensor().lod_level();
+}
 }  // namespace framework
 }  // namespace paddle
4 changes: 4 additions & 0 deletions paddle/framework/var_desc.h
@@ -66,6 +66,10 @@ class VarDescBind {
 
   DataType GetDataType() const;
 
+  void SetLoDLevel(int32_t lod_level);
+
+  int32_t GetLodLevel() const;
+
  private:
   VarDesc desc_;
 };
4 changes: 3 additions & 1 deletion paddle/pybind/protobuf.cc
@@ -166,7 +166,9 @@ void BindVarDsec(py::module &m) {
       .def("set_shape", &VarDescBind::SetShape)
       .def("set_data_type", &VarDescBind::SetDataType)
       .def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
-      .def("data_type", &VarDescBind::GetDataType);
+      .def("data_type", &VarDescBind::GetDataType)
+      .def("lod_level", &VarDescBind::GetLodLevel)
+      .def("set_lod_level", &VarDescBind::SetLoDLevel);
 }
 
 void BindOpDesc(py::module &m) {
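With the two extra `.def(...)` lines, the accessors added in `var_desc.cc` become reachable from Python. A small round-trip sketch (hypothetical usage; it drives the bound `ProgramDesc`/`BlockDesc`/`VarDesc` handles directly, the same way `graph.py` does internally):

```python
import paddle.v2.framework.core as core

prog = core.ProgramDesc.instance()  # singleton, as used by Program.__init__
block = prog.block(0)               # the global BlockDesc
var = block.new_var("lod_example")  # backed by a C++ VarDescBind
var.set_lod_level(2)                # -> VarDescBind::SetLoDLevel
assert var.lod_level() == 2         # -> VarDescBind::GetLodLevel
```

Note the inconsistent C++ spelling (`SetLoDLevel` vs. `GetLodLevel`); the Python-facing names `set_lod_level`/`lod_level` are at least consistent with each other.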
107 changes: 86 additions & 21 deletions python/paddle/v2/framework/graph.py
@@ -1,5 +1,6 @@
 import paddle.v2.framework.core as core
 import collections
+import numpy as np
 
 __all__ = ['Block', 'Variable', 'Program', 'Operator']
 
@@ -11,40 +12,104 @@ def __init__(self, block, name=None, shape=None, dtype=None,
 
         if name is None:
             name = Variable._unique_var_name_()
-        self.proto = self.block.proto.new_var(name)
+        try:
+            self.desc = self.block.desc.var(name)
+            is_new_var = False
+        except core.EnforceNotMet:
+            self.desc = self.block.desc.new_var(name)
+            is_new_var = True
 
         if shape is not None:
-            self.proto.set_shape(shape)
-
+            if is_new_var:
+                self.desc.set_shape(shape)
+            else:
+                old_shape = self.shape
+                shape = tuple(shape)
+                if shape != old_shape:
+                    raise ValueError(
+                        "Variable {0} has been created before. The previous "
+                        "shape is {1}; the new shape is {2}. They are not "
+                        "matched.".format(self.name, old_shape, shape))
         if dtype is not None:
-            # TODO(yuyang18): Convert dtype from numpy.dtype
-            self.proto.set_data_type(dtype)
+            if not isinstance(dtype, core.DataType):
+                dtype = Variable._convert_np_dtype_to_dtype_(dtype)
+            if is_new_var:
+                self.desc.set_data_type(dtype)
+            else:
+                old_dtype = self.data_type
+                if dtype != old_dtype:
+                    raise ValueError("Variable {0} has been created before. "
+                                     "The previous data type is {1}; the new "
+                                     "data type is {2}. They are not "
+                                     "matched.".format(self.name, old_dtype,
+                                                       dtype))
 
         if lod_level is not None:
-            # TODO(yuyang18): set_lod_level is not defined.
-            self.proto.set_lod_level(lod_level)
-
+            if is_new_var:
+                self.desc.set_lod_level(lod_level)
+            else:
+                if lod_level != self.lod_level:
+                    raise ValueError("Variable {0} has been created before. "
+                                     "The previous lod_level is {1}; the new "
+                                     "lod_level is {2}. They are not "
+                                     "matched.".format(self.name, self.lod_level,
+                                                       lod_level))
         self.block.vars[name] = self
         self.op = None
 
-    # TODO(yuyang18): Get methods
+    @property
+    def name(self):
+        return self.desc.name()
+
+    @property
+    def shape(self):
+        # Convert to a tuple, matching the numpy API.
+        return tuple(self.desc.shape())
+
+    @property
+    def data_type(self):
+        return self.desc.data_type()
+
+    @property
+    def lod_level(self):
+        return self.desc.lod_level()
+
     @staticmethod
     def _unique_var_name_():
         uid = core.unique_integer()  # unique during whole process.
         return "_generated_var_%d" % uid
 
+    @staticmethod
+    def _convert_np_dtype_to_dtype_(np_dtype):
+        dtype = np.dtype(np_dtype)
+        if dtype == np.float32:
+            return core.DataType.FP32
+        elif dtype == np.float64:
+            return core.DataType.FP64
+        elif dtype == np.float16:
+            return core.DataType.FP16
+        elif dtype == np.int32:
+            return core.DataType.INT32
+        elif dtype == np.int16:
+            return core.DataType.INT16
+        elif dtype == np.int64:
+            return core.DataType.INT64
+        elif dtype == np.bool:
+            return core.DataType.BOOL
+        else:
+            raise ValueError("Not supported numpy dtype " + str(dtype))
+
 
 class Operator(object):
     def __init__(self,
                  block,
-                 proto,
+                 desc,
                  type=None,
                  inputs=None,
                  outputs=None,
                  attrs=None):
         self.block = block
-        self.proto = proto
+        self.desc = desc
         if type is not None:
             # TODO.
             pass
@@ -58,36 +123,36 @@ def __init__(self,
             # TODO
             pass
 
-    # TODO: Getters
+    # TODO: Getters
 
 
 class Block(object):
     def __init__(self, program, idx):
-        self.proto = program.proto.block(idx)
+        self.desc = program.desc.block(idx)
         self.vars = dict()  # var_name --> var
         self.ops = collections.deque()  # operator list
         self.program = program
 
     @property
     def parent_idx(self):
-        return self.proto.parent
+        return self.desc.parent
 
     @property
     def idx(self):
-        return self.proto.id
+        return self.desc.id
 
     def create_var(self, *args, **kwargs):
         return Variable(self, *args, **kwargs)
 
     def append_op(self, *args, **kwargs):
-        op_proto = self.proto.append_op()
-        op = Operator(self, op_proto, *args, **kwargs)
+        op_desc = self.desc.append_op()
+        op = Operator(self, op_desc, *args, **kwargs)
         self.ops.append(op)
         return op
 
     def prepend_op(self, *args, **kwargs):
-        op_proto = self.proto.prepend_op()
-        op = Operator(self, op_proto, *args, **kwargs)
+        op_desc = self.desc.prepend_op()
+        op = Operator(self, op_desc, *args, **kwargs)
         self.ops.appendleft(op)
         return op
 
@@ -104,7 +169,7 @@ def instance(cls):
     def __init__(self):
         assert not hasattr(self.__class__,
                            '_instance'), 'Do not call constructor directly!'
-        self.proto = core.ProgramDesc.instance()
+        self.desc = core.ProgramDesc.instance()
         self.blocks = [Block(self, 0)]
         self.current_block_idx = 0
 
@@ -116,7 +181,7 @@ def current_block(self):
 
     def create_block(self):
         new_block_idx = len(self.blocks)
-        self.proto.append_block(self.current_block().proto)
+        self.desc.append_block(self.current_block().desc)
         self.current_block_idx = new_block_idx
         self.blocks.append(Block(self, self.current_block_idx))
         return self.current_block()
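The reworked `Variable.__init__` gives `Block.create_var` get-or-create semantics: if the name already exists in the block's `BlockDesc`, the stored `VarDesc` is reused, and any `shape`/`dtype`/`lod_level` passed in must match what was recorded before. A hedged sketch of the resulting behavior (hypothetical usage mirroring the new unit test below):

```python
import numpy as np
import paddle.v2.framework.core as core
from paddle.v2.framework.graph import Variable, g_program

b = g_program.current_block()
w = b.create_var(name="fc.w", shape=[784, 100], dtype="float64", lod_level=0)

w2 = b.create_var(name="fc.w")  # name exists, so the retrieval branch is taken
assert w2.shape == (784, 100)   # attributes come from the stored VarDesc
assert w2.data_type == core.DataType.FP64

try:
    b.create_var(name="fc.w", shape=[24, 100])  # mismatched shape
except ValueError as e:
    print(e)  # "Variable fc.w has been created before. The previous shape..."

# dtype may be a core.DataType, a numpy dtype, or a string:
assert Variable._convert_np_dtype_to_dtype_(np.int64) == core.DataType.INT64
```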
40 changes: 40 additions & 0 deletions python/paddle/v2/framework/tests/test_variable.py
@@ -0,0 +1,40 @@
+import unittest
+from paddle.v2.framework.graph import Variable, g_program
+import paddle.v2.framework.core as core
+import numpy as np
+
+
+class TestVariable(unittest.TestCase):
+    def test_np_dtype_convert(self):
+        DT = core.DataType
+        convert = Variable._convert_np_dtype_to_dtype_
+        self.assertEqual(DT.FP32, convert(np.float32))
+        self.assertEqual(DT.FP16, convert("float16"))
+        self.assertEqual(DT.FP64, convert("float64"))
+        self.assertEqual(DT.INT32, convert("int32"))
+        self.assertEqual(DT.INT16, convert("int16"))
+        self.assertEqual(DT.INT64, convert("int64"))
+        self.assertEqual(DT.BOOL, convert("bool"))
+        self.assertRaises(ValueError, lambda: convert("int8"))
+
+    def test_var(self):
+        b = g_program.current_block()
+        w = b.create_var(
+            dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
+        self.assertEqual(core.DataType.FP64, w.data_type)
+        self.assertEqual((784, 100), w.shape)
+        self.assertEqual("fc.w", w.name)
+        self.assertEqual(0, w.lod_level)
+
+        w = b.create_var(name='fc.w')
+        self.assertEqual(core.DataType.FP64, w.data_type)
+        self.assertEqual((784, 100), w.shape)
+        self.assertEqual("fc.w", w.name)
+        self.assertEqual(0, w.lod_level)
+
+        self.assertRaises(ValueError,
+                          lambda: b.create_var(name="fc.w", shape=(24, 100)))
+
+
+if __name__ == '__main__':
+    unittest.main()