Skip to content

Commit fb57aaf

Browse files
author
Dilawar Singh
committed
overloading template function is not so straight-forward when
typecasting can happen by itslf e.g. int and unsigned int.
1 parent 22aeb43 commit fb57aaf

File tree

6 files changed

+67
-51
lines changed

6 files changed

+67
-51
lines changed

pybind11/MooseVec.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,34 @@ ObjId MooseVec::getFieldItem(const size_t i) const
9494
return ObjId(oid_.path(), oid_.dataIndex, i);
9595
}
9696

97-
vector<py::object> MooseVec::getAttribute(const string& name)
98-
{
97+
py::object MooseVec::getAttribute(const string& name)
98+
{
99+
// If type if double, int, bool etc, then return the numpy array. else
100+
// return the list of python object.
101+
auto cinfo = oid_.element()->cinfo();
102+
auto finfo = cinfo->findFinfo(name);
103+
auto rttType = finfo->rttiType();
104+
105+
if(rttType == "double")
106+
return getAttributeNumpy<double>(name);
107+
if(rttType == "unsigned int")
108+
return getAttributeNumpy<unsigned int>(name);
109+
if(rttType == "int")
110+
return getAttributeNumpy<unsigned int>(name);
111+
112+
// FIXME: bool type is not working. Need to raise the ticket on pybind11
113+
// after creating and MWE.
114+
//if(rttType == "bool")
115+
// return getAttributeNumpy<bool>(name);
116+
99117
vector<py::object> res(size());
100118
for (unsigned int i = 0; i < size(); i++)
101119
res[i] = getFieldGeneric(getItem(i), name);
102-
return res;
120+
return py::cast(res);
103121
}
104122

105123

124+
106125
// // FIXME: Only double is supported here. Not sure if this is enough. This
107126
// // should be the API function.
108127
// py::array_t<double> MooseVec::getAttributeNumpy(const string &name)

pybind11/MooseVec.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,11 @@ class MooseVec
7676

7777

7878
// Get attributes.
79-
vector<py::object> getAttribute(const string& name);
79+
py::object getAttribute(const string& name);
8080

81-
// TODO: May be buffer https://pybind11.readthedocs.io/en/stable/advanced/pycpp/numpy.html#buffer-protocol
82-
template<typename T=double>
81+
vector<ObjId> objs() const;
82+
83+
template<typename T>
8384
py::array_t<T> getAttributeNumpy(const string& name)
8485
{
8586
vector<T> res(size());
@@ -88,9 +89,6 @@ class MooseVec
8889
return py::array_t<T>(res.size(), res.data());
8990
}
9091

91-
vector<ObjId> objs() const;
92-
93-
9492
ObjId connectToSingle(const string& srcfield, const ObjId& tgt, const string& tgtfield, const string& msgtype);
9593

9694
ObjId connectToVec(const string& srcfield, const MooseVec& tgt, const string& tgtfield, const string& msgtype);

pybind11/helper.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,9 @@ inline ObjId mooseObjIdPath(const string& p)
5656
string path(p);
5757

5858
// If path is a relative path.
59-
if (p[0] != '/')
60-
{
59+
if (p[0] != '/') {
6160
string cwepath(mooseGetCweId().path());
62-
if(cwepath.back() != '/')
61+
if (cwepath.back() != '/')
6362
cwepath.push_back('/');
6463
path = cwepath + p;
6564
}

pybind11/pymoose.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,9 @@ PYBIND11_MODULE(_moose, m)
309309
.def("__setattr__", &MooseVec::setAttrOneToAll<double>)
310310
.def("__setattr__", &MooseVec::setAttrOneToAll<string>)
311311
.def("__setattr__", &MooseVec::setAttrOneToAll<bool>)
312-
.def("__getattr__", &MooseVec::getAttributeNumpy<double>)
313-
.def("__getattr__", &MooseVec::getAttributeNumpy<float>)
314-
// These three are probably never needed.
315-
.def("__getattr__", &MooseVec::getAttributeNumpy<unsigned int>)
316-
.def("__getattr__", &MooseVec::getAttributeNumpy<unsigned long>)
317-
// For the rest non-POD types.
312+
// Beware of pybind11 overload resolution order:
313+
// https://pybind11.readthedocs.io/en/stable/advanced/functions.html#overload-resolution-order
314+
// Templated function won't work here. The first one is always called.
318315
.def("__getattr__", &MooseVec::getAttribute)
319316
.def("__repr__", [](const MooseVec & v)->string {
320317
return "<moose.vec class=" + v.dtype() + " path=" + v.path() +

pybind11/pymoose.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ inline T getField(const ObjId& id, const string& fname)
3636

3737
// FIXME: Is it most efficient?
3838
// See discussion here: https://github.com/pybind/pybind11/issues/1042
39-
template <typename T = double>
39+
template <typename T>
4040
inline py::array_t<T> getFieldNumpy(const ObjId& id, const string& fname)
4141
{
4242
auto v = Field<vector<T>>::get(id, fname);

tests/py_moose/test_connectionLists.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,11 @@
4545

4646
def makeGlobalBalanceNetwork():
4747
stim = moose.RandSpike( '/model/stim', params['numInputs'] )
48-
4948
inhib = moose.LIF( '/model/inhib', params['numInhib'] )
50-
5149
insyn = moose.SimpleSynHandler(inhib.path + '/syns', params['numInhib'])
52-
5350
moose.connect( insyn, 'activationOut', inhib, 'activation', 'OneToOne' )
5451
output = moose.LIF( '/model/output', params['numOutput'] )
55-
5652
outsyn = moose.SimpleSynHandler(output.path+'/syns',params['numOutput'])
57-
5853
moose.connect(outsyn, 'activationOut', output, 'activation', 'OneToOne')
5954
outInhSyn = moose.SimpleSynHandler(output.path+'/inhsyns',params['numOutput'])
6055
moose.connect(outInhSyn, 'activationOut', output, 'activation', 'OneToOne')
@@ -63,41 +58,47 @@ def makeGlobalBalanceNetwork():
6358
ov = moose.vec( outsyn.path + '/synapse' )
6459
oiv = moose.vec( outInhSyn.path + '/synapse' )
6560

66-
inhibMatrix = moose.connect( stim, 'spikeOut', iv, 'addSpike', 'Sparse' )
67-
# inhibMatrix = moose.element(temp)
61+
assert len(iv) == 0
62+
assert len(ov) == 0
63+
assert len(oiv) == 0
6864

69-
inhibMatrix.setRandomConnectivity(params['stimToInhProb'], params['stimToInhSeed'])
65+
temp = moose.connect( stim, 'spikeOut', iv, 'addSpike', 'Sparse' )
66+
inhibMatrix = moose.element( temp )
67+
inhibMatrix.setRandomConnectivity(
68+
params['stimToInhProb'], params['stimToInhSeed'] )
7069
cl = inhibMatrix.connectionList
7170

7271
# This can change when random-number generator changes.
72+
# This was before we used c++11 <random> to generate random numbers. This
73+
# test has changes on Tuesday 31 July 2018 11:12:35 AM IST
7374
# expectedCl = [ 1,4,13,13,26,42,52,56,80,82,95,97,4,9,0,9,4,8,0,6,1,6,6,7]
7475
expectedCl=[0,6,47,50,56,67,98,2,0,3,5,4,8,3]
7576

76-
# print('CL', cl)
7777
assert list(cl) == expectedCl, "Expected %s, got %s" % (expectedCl, cl)
78-
assert inhibMatrix.numEntries == 7, inhibMatrix.numEntries
7978

80-
excMatrix = moose.connect( stim, 'spikeOut', ov, 'addSpike', 'Sparse' )
81-
# excMatrix = moose.element(temp)
82-
print('111', excMatrix)
83-
excMatrix.setRandomConnectivity(params['stimToOutProb'], params['stimToOutSeed'])
84-
assert excMatrix.numEntries == 62, excMatrix.numEntries
79+
temp = moose.connect( stim, 'spikeOut', ov, 'addSpike', 'Sparse' )
80+
excMatrix = moose.element( temp )
81+
excMatrix.setRandomConnectivity(
82+
params['stimToOutProb'], params['stimToOutSeed'] )
8583

86-
negFFMatrix = moose.connect(inhib, 'spikeOut', oiv, 'addSpike', 'Sparse')
87-
# negFFMatrix = moose.element( temp )
88-
negFFMatrix.setRandomConnectivity(params['inhToOutProb'], params['inhToOutSeed'] )
89-
assert negFFMatrix.numEntries == 55, negFFMatrix.numEntries
84+
temp = moose.connect( inhib, 'spikeOut', oiv, 'addSpike', 'Sparse' )
85+
negFFMatrix = moose.element( temp )
86+
negFFMatrix.setRandomConnectivity(
87+
params['inhToOutProb'], params['inhToOutSeed'] )
88+
89+
# print("ConnMtxEntries: ", inhibMatrix.numEntries, excMatrix.numEntries, negFFMatrix.numEntries)
90+
got = (inhibMatrix.numEntries, excMatrix.numEntries, negFFMatrix.numEntries)
91+
expected = (7, 62, 55)
92+
assert expected == got, "Expected %s, Got %s" % (expected,got)
9093

9194
cl = negFFMatrix.connectionList
9295
numInhSyns = [ ]
9396
niv = 0
9497
nov = 0
9598
noiv = 0
96-
97-
insyn = moose.vec(insyn)
98-
for i in insyn:
99+
for i in moose.vec( insyn ):
99100
niv += i.synapse.num
100-
numInhSyns.append( i.synapse.num)
101+
numInhSyns.append( i.synapse.num )
101102
if i.synapse.num > 0:
102103
i.synapse.weight = params['wtStimToInh']
103104

@@ -106,28 +107,30 @@ def makeGlobalBalanceNetwork():
106107
assert numInhSyns == expected, "Expected %s, got %s" % (expected,numInhSyns)
107108

108109
for i in moose.vec( outsyn ):
110+
print('111', i)
109111
nov += i.synapse.num
110112
if i.synapse.num > 0:
111113
i.synapse.weight = params['wtStimToOut']
112114
for i in moose.vec( outInhSyn ):
113115
noiv += i.synapse.num
116+
#print i.synapse.num
114117
if i.synapse.num > 0:
115118
i.synapse.weight = params['wtInhToOut']
116-
117-
print(iv.numField, ov.numField, oiv.numField)
119+
118120
print("SUMS: ", sum( iv.numField ), sum( ov.numField ), sum( oiv.numField ))
119-
# assert [1, 64, 25] == [sum( iv.numField ), sum( ov.numField ), sum( oiv.numField )]
120-
assert [1, 50, 27] == [sum( iv.numField ), sum( ov.numField ), sum( oiv.numField )]
121+
assert [1, 64, 25] == [sum( iv.numField ), sum( ov.numField ), sum( oiv.numField )]
121122
print("SUMS2: ", niv, nov, noiv)
122-
# assert [7, 62, 55] == [ niv, nov, noiv ]
123123
assert [7, 62, 55] == [ niv, nov, noiv ]
124-
125-
print(insyn.vec)
126-
print(outsyn.vec)
127-
print(outInhSyn)
128124
print("SUMS3: ", sum( insyn.vec.numSynapses ), sum( outsyn.vec.numSynapses ), sum( outInhSyn.vec.numSynapses ))
129125
assert [7,62,55] == [ sum( insyn.vec.numSynapses ), sum( outsyn.vec.numSynapses ), sum( outInhSyn.vec.numSynapses ) ]
130126

127+
# print(oiv.numField)
128+
# print(insyn.vec[1].synapse.num)
129+
# print(insyn.vec.numSynapses)
130+
# print(sum( insyn.vec.numSynapses ))
131+
# niv = iv.numSynapses
132+
# ov = iv.numSynapses
133+
131134
sv = moose.vec( stim )
132135
sv.rate = params['randInputRate']
133136
sv.refractT = params['randRefractTime']

0 commit comments

Comments
 (0)