Skip to content

Commit 8cee538

Browse files
authored
Merge pull request #112 from philpep/ansible-vault
Various ansible fixes
2 parents 62b7d07 + f69c802 commit 8cee538

File tree

5 files changed

+114
-49
lines changed

5 files changed

+114
-49
lines changed

testinfra/backend/ansible.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import pprint
1919

2020
from testinfra.backend import base
21-
import testinfra.utils.ansible_runner as ansible_runner
2221

2322
logger = logging.getLogger("testinfra")
2423

@@ -30,8 +29,16 @@ class AnsibleBackend(base.BaseBackend):
3029
def __init__(self, host, ansible_inventory=None, *args, **kwargs):
3130
self.host = host
3231
self.ansible_inventory = ansible_inventory
32+
self._ansible_runner = None
3333
super(AnsibleBackend, self).__init__(host, *args, **kwargs)
3434

35+
@property
36+
def ansible_runner(self):
37+
if self._ansible_runner is None:
38+
from testinfra.utils.ansible_runner import AnsibleRunner
39+
self._ansible_runner = AnsibleRunner(self.ansible_inventory)
40+
return self._ansible_runner
41+
3542
def run(self, command, *args):
3643
command = self.get_command(command, *args)
3744
out = self.run_ansible("shell", module_args=command)
@@ -53,9 +60,8 @@ def run(self, command, *args):
5360
return result
5461

5562
def run_ansible(self, module_name, module_args=None, **kwargs):
56-
result = ansible_runner.run(
63+
result = self.ansible_runner.run(
5764
self.host, module_name, module_args,
58-
host_list=self.ansible_inventory,
5965
**kwargs)
6066
logger.info(
6167
"RUN Ansible(%s, %s, %s): %s",
@@ -64,9 +70,9 @@ def run_ansible(self, module_name, module_args=None, **kwargs):
6470
return result
6571

6672
def get_variables(self):
67-
return ansible_runner.get_variables(
68-
self.host, host_list=self.ansible_inventory)
73+
return self.ansible_runner.get_variables(self.host)
6974

7075
@classmethod
7176
def get_hosts(cls, host, **kwargs):
72-
return ansible_runner.get_hosts(kwargs.get("ansible_inventory"), host)
77+
from testinfra.utils.ansible_runner import AnsibleRunner
78+
return AnsibleRunner(kwargs.get("ansible_inventory")).get_hosts(host)

testinfra/test/conftest.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,25 +49,57 @@ def has_docker():
4949
return _HAS_DOCKER
5050

5151

52-
def get_ansible_inventory(name, hostname, user, port, key):
52+
# Generated with
53+
# $ echo myhostvar: bar > hostvars.yml
54+
# $ echo polichinelle > vault-pass.txt
55+
# $ ansible-vault encrypt --vault-password-file vault-pass.txt hostvars.yml
56+
# $ cat hostvars.yml
57+
ANSIBLE_HOSTVARS = """$ANSIBLE_VAULT;1.1;AES256
58+
39396233323131393835363638373764336364323036313434306134636633353932623363646233
59+
6436653132383662623364313438376662666135346266370a343934663431363661393363386633
60+
64656261336662623036373036363535313964313538366533313334366363613435303066316639
61+
3235393661656230350a326264356530326432393832353064363439393330616634633761393838
62+
3261
63+
"""
64+
65+
66+
def setup_ansible_config(tmpdir, name, host, user, port, key):
5367
ansible_major_version = int(ansible.__version__.split(".", 1)[0])
5468
items = [
5569
name,
5670
"ansible_ssh_private_key_file={}".format(key),
71+
"myvar=foo",
5772
]
5873
if ansible_major_version == 1:
5974
items.extend([
60-
"ansible_ssh_host={}".format(hostname),
75+
"ansible_ssh_host={}".format(host),
6176
"ansible_ssh_user={}".format(user),
6277
"ansible_ssh_port={}".format(port),
6378
])
6479
elif ansible_major_version == 2:
6580
items.extend([
66-
"ansible_host={}".format(hostname),
81+
"ansible_host={}".format(host),
6782
"ansible_user={}".format(user),
6883
"ansible_port={}".format(port),
6984
])
70-
return " ".join(items) + "\n"
85+
tmpdir.join("inventory").write(
86+
"[testgroup]\n" + " ".join(items) + "\n")
87+
tmpdir.mkdir("host_vars").join(name).write(ANSIBLE_HOSTVARS)
88+
tmpdir.mkdir("group_vars").join("testgroup").write((
89+
"---\n"
90+
"myhostvar: should_be_overriden\n"
91+
"mygroupvar: qux\n"
92+
))
93+
vault_password_file = tmpdir.join("vault-pass.txt")
94+
vault_password_file.write("polichinelle\n")
95+
ansible_cfg = tmpdir.join("ansible.cfg")
96+
ansible_cfg.write((
97+
"[defaults]\n"
98+
"vault_password_file={}\n"
99+
"host_key_checking=False\n\n"
100+
"[ssh_connection]\n"
101+
"pipelining=True\n"
102+
).format(str(vault_password_file)))
71103

72104

73105
def build_docker_container_fixture(image, scope):
@@ -138,10 +170,11 @@ def TestinfraBackend(request, tmpdir_factory):
138170
if ansible is None:
139171
pytest.skip()
140172
return
141-
inventory = tmpdir.join("inventory")
142-
inventory.write(get_ansible_inventory(
143-
host, docker_host, user or "root", port, str(key)))
144-
kw["ansible_inventory"] = str(inventory)
173+
setup_ansible_config(
174+
tmpdir, host, docker_host, user or "root", port, str(key))
175+
os.environ["ANSIBLE_CONFIG"] = str(tmpdir.join("ansible.cfg"))
176+
# this force backend cache reloading
177+
kw["ansible_inventory"] = str(tmpdir.join("inventory"))
145178
else:
146179
ssh_config = tmpdir.join("ssh_config")
147180
ssh_config.write((

testinfra/test/test_backends.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,15 @@ def test_user_connection(User):
6666
@pytest.mark.testinfra_hosts(*SUDO_HOSTS)
6767
def test_sudo(User):
6868
assert User().name == "root"
69+
70+
71+
@pytest.mark.testinfra_hosts("ansible://debian_jessie")
72+
def test_ansible_hosts_expand(TestinfraBackend):
73+
from testinfra.backend.ansible import AnsibleBackend
74+
75+
def get_hosts(spec):
76+
return AnsibleBackend.get_hosts(
77+
spec, ansible_inventory=TestinfraBackend.ansible_inventory)
78+
assert get_hosts(["all"]) == ["debian_jessie"]
79+
assert get_hosts(["testgroup"]) == ["debian_jessie"]
80+
assert get_hosts(["*ia*jess*"]) == ["debian_jessie"]

testinfra/test/test_modules.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def test_ansible_module(TestinfraBackend, Ansible):
263263
version = int(ansible.__version__.split(".", 1)[0])
264264
setup = Ansible("setup")["ansible_facts"]
265265
assert setup["ansible_lsb"]["codename"] == "jessie"
266-
passwd = Ansible("file", "path=/etc/passwd")
266+
passwd = Ansible("file", "path=/etc/passwd state=file")
267267
assert passwd["changed"] is False
268268
assert passwd["gid"] == 0
269269
assert passwd["group"] == "root"
@@ -275,8 +275,11 @@ def test_ansible_module(TestinfraBackend, Ansible):
275275
assert passwd["uid"] == 0
276276

277277
variables = Ansible.get_variables()
278+
assert variables["myvar"] == "foo"
279+
assert variables["myhostvar"] == "bar"
280+
assert variables["mygroupvar"] == "qux"
278281
assert variables["inventory_hostname"] == "debian_jessie"
279-
assert variables["group_names"] == ["ungrouped"]
282+
assert variables["group_names"] == ["testgroup"]
280283

281284
# test errors reporting
282285
with pytest.raises(Ansible.AnsibleException) as excinfo:

testinfra/utils/ansible_runner.py

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,28 @@
2525
else:
2626
_has_ansible = True
2727
_ansible_major_version = int(ansible.__version__.split(".", 1)[0])
28+
import ansible.constants
2829
if _ansible_major_version == 1:
2930
import ansible.inventory
3031
import ansible.runner
32+
import ansible.utils
3133
elif _ansible_major_version == 2:
34+
import ansible.cli
3235
import ansible.executor.task_queue_manager
3336
import ansible.inventory
3437
import ansible.parsing.dataloader
3538
import ansible.playbook.play
3639
import ansible.plugins.callback
40+
import ansible.utils.vars
3741
import ansible.vars
3842

3943

44+
def _reload_constants():
45+
# Reload defaults that can depend on environment variables and
46+
# current working directory
47+
reload(ansible.constants)
48+
49+
4050
class AnsibleRunnerBase(object):
4151

4252
def __init__(self, host_list=None):
@@ -71,7 +81,13 @@ class AnsibleRunnerV1(AnsibleRunnerBase):
7181

7282
def __init__(self, host_list=None):
7383
super(AnsibleRunnerV1, self).__init__(host_list)
74-
self.inventory = ansible.inventory.Inventory(self.host_list)
84+
_reload_constants()
85+
self.vault_pass = ansible.utils.read_vault_file(
86+
ansible.constants.DEFAULT_VAULT_PASSWORD_FILE)
87+
kwargs = {"vault_password": self.vault_pass}
88+
if self.host_list is not None:
89+
kwargs["host_list"] = host_list
90+
self.inventory = ansible.inventory.Inventory(**kwargs)
7591

7692
def get_hosts(self, pattern=None):
7793
return [
@@ -91,6 +107,8 @@ def run(self, host, module_name, module_args=None, **kwargs):
91107
result = ansible.runner.Runner(
92108
pattern=host,
93109
module_name=module_name,
110+
vault_pass=self.vault_pass,
111+
inventory=self.inventory,
94112
**kwargs).run()
95113
if host not in result["contacted"]:
96114
raise RuntimeError("Unexpected error: {}".format(result))
@@ -100,21 +118,6 @@ def run(self, host, module_name, module_args=None, **kwargs):
100118
return result["contacted"][host]
101119

102120

103-
class Options(object):
104-
105-
def __init__(self, **kwargs):
106-
self.connection = "smart"
107-
for attr in (
108-
"module_path", "forks", "remote_user", "private_key_file",
109-
"ssh_common_args", "ssh_extra_args", "sftp_extra_args",
110-
"scp_extra_args", "become", "become_method", "become_user",
111-
"verbosity",
112-
):
113-
setattr(self, attr, None)
114-
self.check = kwargs.get("check", False)
115-
super(Options, self).__init__()
116-
117-
118121
if _has_ansible and _ansible_major_version == 2:
119122
class Callback(ansible.plugins.callback.CallbackBase):
120123

@@ -147,12 +150,31 @@ class AnsibleRunnerV2(AnsibleRunnerBase):
147150

148151
def __init__(self, host_list=None):
149152
super(AnsibleRunnerV2, self).__init__(host_list)
153+
_reload_constants()
150154
self.variable_manager = ansible.vars.VariableManager()
155+
self.options = ansible.cli.CLI(None).base_parser(
156+
connect_opts=True,
157+
meta_opts=True,
158+
runas_opts=True,
159+
subset_opts=True,
160+
check_opts=True,
161+
inventory_opts=True,
162+
runtask_opts=True,
163+
vault_opts=True,
164+
fork_opts=True,
165+
module_opts=True,
166+
).parse_args([])[0]
167+
self.options.connection = "smart"
151168
self.loader = ansible.parsing.dataloader.DataLoader()
169+
if self.options.vault_password_file:
170+
vault_pass = ansible.cli.CLI.read_vault_password_file(
171+
self.options.vault_password_file, loader=self.loader)
172+
self.loader.set_vault_password(vault_pass)
173+
152174
self.inventory = ansible.inventory.Inventory(
153175
loader=self.loader,
154176
variable_manager=self.variable_manager,
155-
host_list=host_list,
177+
host_list=host_list or self.options.inventory,
156178
)
157179
self.variable_manager.set_inventory(self.inventory)
158180

@@ -163,9 +185,12 @@ def get_hosts(self, pattern=None):
163185
]
164186

165187
def get_variables(self, host):
166-
return self.inventory.get_vars(host)
188+
host = self.inventory.get_host(host)
189+
return ansible.utils.vars.combine_vars(
190+
host.get_group_vars(), host.get_vars())
167191

168192
def run(self, host, module_name, module_args=None, **kwargs):
193+
self.options.check = kwargs.get("check", False)
169194
action = {"module": module_name}
170195
if module_args is not None:
171196
if module_name in ("command", "shell"):
@@ -180,14 +205,13 @@ def run(self, host, module_name, module_args=None, **kwargs):
180205
}],
181206
}, variable_manager=self.variable_manager, loader=self.loader)
182207
tqm = None
183-
options = Options(**kwargs)
184208
callback = Callback()
185209
try:
186210
tqm = ansible.executor.task_queue_manager.TaskQueueManager(
187211
inventory=self.inventory,
188212
variable_manager=self.variable_manager,
189213
loader=self.loader,
190-
options=options,
214+
options=self.options,
191215
passwords=None,
192216
stdout_callback=callback,
193217
)
@@ -208,16 +232,3 @@ def run(self, host, module_name, module_args=None, **kwargs):
208232
raise NotImplementedError(
209233
"Unhandled ansible version " + ansible.__version__
210234
)
211-
212-
213-
def get_hosts(host_list=None, pattern=None):
214-
return AnsibleRunner(host_list).get_hosts(pattern)
215-
216-
217-
def run(host, module_name, module_args=None, host_list=None, **kwargs):
218-
return AnsibleRunner(host_list).run(
219-
host, module_name, module_args, **kwargs)
220-
221-
222-
def get_variables(host, host_list=None):
223-
return AnsibleRunner(host_list).get_variables(host)

0 commit comments

Comments
 (0)