Skip to content

Commit a3cc81c

Browse files
authored
move BART to its own module (#5058)
* move BART to its own module * add missing file
1 parent 70f1975 commit a3cc81c

File tree

10 files changed

+26
-9
lines changed

10 files changed

+26
-9
lines changed

Diff for: pymc/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __set_compiler_flags():
8383
to_inference_data,
8484
)
8585
from pymc.backends.tracetab import *
86+
from pymc.bart import *
8687
from pymc.blocking import *
8788
from pymc.data import *
8889
from pymc.distributions import *

Diff for: pymc/bart/__init__.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from pymc.bart.bart import BART
17+
from pymc.bart.pgbart import PGBART
18+
19+
__all__ = ["BART", "PGBART"]
File renamed without changes.

Diff for: pymc/step_methods/pgbart.py renamed to pymc/bart/pgbart.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
from pandas import DataFrame, Series
2424

2525
from pymc.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements
26+
from pymc.bart.bart import BARTRV
27+
from pymc.bart.tree import LeafNode, SplitNode, Tree
2628
from pymc.blocking import RaveledVars
27-
from pymc.distributions.bart import BARTRV
28-
from pymc.distributions.tree import LeafNode, SplitNode, Tree
2929
from pymc.model import modelcontext
3030
from pymc.step_methods.arraystep import ArrayStepShared, Competence
3131

File renamed without changes.

Diff for: pymc/distributions/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
logpt_sum,
2323
)
2424

25-
from pymc.distributions.bart import BART
2625
from pymc.distributions.bound import Bound
2726
from pymc.distributions.continuous import (
2827
AsymmetricLaplace,
@@ -190,7 +189,6 @@
190189
"Rice",
191190
"Moyal",
192191
"Simulator",
193-
"BART",
194192
"CAR",
195193
"PolyaGamma",
196194
"logpt",

Diff for: pymc/sampling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@
4141
from pymc.backends.arviz import _DefaultTrace
4242
from pymc.backends.base import BaseTrace, MultiTrace
4343
from pymc.backends.ndarray import NDArray
44+
from pymc.bart.pgbart import PGBART
4445
from pymc.blocking import DictToArrayBijection
4546
from pymc.distributions import NoDistribution
4647
from pymc.exceptions import IncorrectArgumentsError, SamplingError
4748
from pymc.model import Model, Point, modelcontext
4849
from pymc.parallel_sampling import Draw, _cpu_count
4950
from pymc.step_methods import (
5051
NUTS,
51-
PGBART,
5252
BinaryGibbsMetropolis,
5353
BinaryMetropolis,
5454
CategoricalGibbsMetropolis,

Diff for: pymc/step_methods/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,4 @@
3535
MetropolisMLDA,
3636
RecursiveDAProposal,
3737
)
38-
from pymc.step_methods.pgbart import PGBART
3938
from pymc.step_methods.slicer import Slice

Diff for: pymc/step_methods/hmc/nuts.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from pymc.aesaraf import floatX
2020
from pymc.backends.report import SamplerWarning, WarningType
21-
from pymc.distributions.bart import BARTRV
21+
from pymc.bart.bart import BARTRV
2222
from pymc.math import logbern, logdiffexp_numpy
2323
from pymc.step_methods.arraystep import Competence
2424
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData

Diff for: pymc/tests/test_bart.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def test_split_node():
10-
split_node = pm.distributions.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0)
10+
split_node = pm.bart.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0)
1111
assert split_node.index == 5
1212
assert split_node.idx_split_variable == 2
1313
assert split_node.split_value == 3.0
@@ -18,7 +18,7 @@ def test_split_node():
1818

1919

2020
def test_leaf_node():
21-
leaf_node = pm.distributions.tree.LeafNode(index=5, value=3.14, idx_data_points=[1, 2, 3])
21+
leaf_node = pm.bart.tree.LeafNode(index=5, value=3.14, idx_data_points=[1, 2, 3])
2222
assert leaf_node.index == 5
2323
assert np.array_equal(leaf_node.idx_data_points, [1, 2, 3])
2424
assert leaf_node.value == 3.14

0 commit comments

Comments
 (0)