Skip to content

Commit f0de9d4

Browse files
authored
Merge pull request #2494 from eerovaher/simbad-refractor
Simplify `simbad` code
2 parents 1dea6ac + 35b13ec commit f0de9d4

File tree

4 files changed

+55
-84
lines changed

4 files changed

+55
-84
lines changed

CHANGES.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ oac
8585
- Fix bug in parsing events that contain html tags (e.g. in their alias
8686
field). [#2423]
8787

88+
simbad
89+
^^^^^^
90+
91+
- It is now possible to specify multiple coordinates together with a single
92+
radius as a string in ``query_region()`` and ``query_region_async()``.
93+
[#2494]
94+
8895
svo_fps
8996
^^^^^^^
9097

astroquery/simbad/core.py

Lines changed: 39 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
Simbad query class for accessing the Simbad Service
44
"""
55

6-
import copy
76
import re
87
import requests
98
import json
@@ -12,6 +11,7 @@
1211
from io import BytesIO
1312
import warnings
1413
import astropy.units as u
14+
from astropy.utils import isiterable
1515
from astropy.utils.data import get_pkg_data_filename
1616
import astropy.coordinates as coord
1717
from astropy.table import Table
@@ -202,25 +202,20 @@ class SimbadBibcodeResult(SimbadResult):
202202
"""Bibliography-type Simbad result"""
203203
@property
204204
def table(self):
205-
bibcode_match = bibcode_regex.search(self.script)
206-
splitter = bibcode_match.group(2)
207-
ref_list = [splitter + ref for ref in self.data.split(splitter)][1:]
208-
max_len = max([len(r) for r in ref_list])
209-
table = Table(names=['References'], dtype=['U%i' % max_len])
210-
for ref in ref_list:
211-
table.add_row([ref])
212-
return table
205+
splitter = bibcode_regex.search(self.script).group(2)
206+
ref_list = [[splitter + ref] for ref in self.data.split(splitter)[1:]]
207+
max_len = max(len(r[0]) for r in ref_list)
208+
return Table(rows=ref_list, names=['References'], dtype=[f"U{max_len}"])
213209

214210

215211
class SimbadObjectIDsResult(SimbadResult):
216212
"""Object identifier list Simbad result"""
217213
@property
218214
def table(self):
219-
max_len = max([len(i) for i in self.data.splitlines()])
220-
table = Table(names=['ID'], dtype=['S%i' % max_len])
221-
for id in self.data.splitlines():
222-
table.add_row([id.strip()])
223-
return table
215+
split_lines = self.data.splitlines()
216+
ids = [[id.strip()] for id in split_lines]
217+
max_len = max(map(len, split_lines))
218+
return Table(rows=ids, names=['ID'], dtype=[f"S{max_len}"])
224219

225220

226221
class SimbadBaseQuery(BaseQuery):
@@ -281,8 +276,6 @@ class SimbadClass(SimbadBaseQuery):
281276
),
282277
'[^0-9]': 'Any (one) character not in the list.'}
283278

284-
_ORDERED_WILDCARDS = ['*', '?', '[abc]', '[^0-9]']
285-
286279
# query around not included since this is a subcase of query_region
287280
_function_to_command = {
288281
'query_object_async': 'query id',
@@ -303,7 +296,7 @@ class SimbadClass(SimbadBaseQuery):
303296

304297
def __init__(self):
305298
super().__init__()
306-
self._VOTABLE_FIELDS = copy.copy(self._VOTABLE_FIELDS)
299+
self._VOTABLE_FIELDS = self._VOTABLE_FIELDS.copy()
307300

308301
def list_wildcards(self):
309302
"""
@@ -323,10 +316,7 @@ def list_wildcards(self):
323316
[abc] : Exactly one character taken in the list.
324317
Can also be defined by a range of characters: [A-Z]
325318
"""
326-
for key in self._ORDERED_WILDCARDS:
327-
print("{key} : {value}\n".format(key=key,
328-
value=self.WILDCARDS[key]))
329-
return
319+
print("\n\n".join(f"{k} : {v}" for k, v in self.WILDCARDS.items()))
330320

331321
def list_votable_fields(self):
332322
"""
@@ -353,8 +343,8 @@ def list_votable_fields(self):
353343
fields_dict = json.load(f)
354344

355345
print("Available VOTABLE fields:\n")
356-
for field in list(sorted(fields_dict.keys())):
357-
print("{}".format(field))
346+
for field in sorted(fields_dict.keys()):
347+
print(str(field))
358348
print("For more information on a field:\n"
359349
"Simbad.get_field_description ('field_name') \n"
360350
"Currently active VOTABLE fields:\n {0}"
@@ -415,10 +405,7 @@ def add_votable_fields(self, *args):
415405
os.path.join('data', 'votable_fields_dict.json'))
416406

417407
with open(dict_file, "r") as f:
418-
fields_dict = json.load(f)
419-
fields_dict = dict(
420-
((strip_field(ff), fields_dict[ff])
421-
for ff in fields_dict))
408+
fields_dict = {strip_field(k): v for k, v in json.load(f).items()}
422409

423410
for field in args:
424411
sf = strip_field(field)
@@ -427,34 +414,30 @@ def add_votable_fields(self, *args):
427414
else:
428415
self._VOTABLE_FIELDS.append(field)
429416

430-
def remove_votable_fields(self, *args, **kwargs):
417+
def remove_votable_fields(self, *args, strip_params=False):
431418
"""
432419
Removes the specified field names from ``SimbadClass._VOTABLE_FIELDS``
433420
434421
Parameters
435422
----------
436423
list of field_names to be removed
437-
strip_params: bool
424+
strip_params: bool, optional
438425
If true, strip the specified keywords before removing them:
439426
e.g., ra(foo) would remove ra(bar) if this is True
440427
"""
441-
strip_params = kwargs.pop('strip_params', False)
442-
443428
if strip_params:
444-
sargs = [strip_field(a) for a in args]
429+
sargs = {strip_field(a) for a in args}
445430
sfields = [strip_field(a) for a in self._VOTABLE_FIELDS]
446431
else:
447-
sargs = args
432+
sargs = set(args)
448433
sfields = self._VOTABLE_FIELDS
449-
absent_fields = set(sargs) - set(sfields)
450-
451-
for b, f in list(zip(sfields, self._VOTABLE_FIELDS)):
452-
if b in sargs:
453-
self._VOTABLE_FIELDS.remove(f)
454434

455-
for field in absent_fields:
435+
for field in sargs.difference(sfields):
456436
warnings.warn("{field}: this field is not set".format(field=field))
457437

438+
zipped_fields = zip(sfields, self._VOTABLE_FIELDS)
439+
self._VOTABLE_FIELDS = [f for b, f in zipped_fields if b not in sargs]
440+
458441
# check if all fields are removed
459442
if not self._VOTABLE_FIELDS:
460443
warnings.warn("All fields have been removed. "
@@ -678,8 +661,6 @@ def query_region_async(self, coordinates, radius=2*u.arcmin,
678661

679662
# handle the vector case
680663
if isinstance(ra, list):
681-
vector = True
682-
683664
if len(ra) > 10000:
684665
warnings.warn("For very large queries, you may receive a "
685666
"timeout error. SIMBAD suggests splitting "
@@ -689,21 +670,19 @@ def query_region_async(self, coordinates, radius=2*u.arcmin,
689670
if len(set(frame)) > 1:
690671
raise ValueError("Coordinates have different frames")
691672
else:
692-
frame = set(frame).pop()
673+
frame = frame[0]
693674

694-
if vector and _has_length(radius) and len(radius) == len(ra):
695-
# all good, continue
696-
pass
697-
elif vector and _has_length(radius) and len(radius) != len(ra):
698-
raise ValueError("Mismatch between radii and coordinates")
699-
elif vector and not _has_length(radius):
675+
# `radius` as `str` is iterable, but contains only one value.
676+
if isiterable(radius) and not isinstance(radius, str):
677+
if len(radius) != len(ra):
678+
raise ValueError("Mismatch between radii and coordinates")
679+
else:
700680
radius = [_parse_radius(radius)] * len(ra)
701681

702-
if vector:
703-
query_str = "\n".join([base_query_str
704-
.format(ra=ra_, dec=dec_, rad=rad_,
705-
frame=frame, equinox=equinox)
706-
for ra_, dec_, rad_ in zip(ra, dec, radius)])
682+
query_str = "\n".join(base_query_str
683+
.format(ra=ra_, dec=dec_, rad=rad_,
684+
frame=frame, equinox=equinox)
685+
for ra_, dec_, rad_ in zip(ra, dec, radius))
707686

708687
else:
709688
radius = _parse_radius(radius)
@@ -956,20 +935,13 @@ def query_objectids_async(self, object_name, cache=True,
956935
return response
957936

958937
def _get_query_header(self, get_raw=False):
959-
votable_fields = ','.join(self.get_votable_fields())
960938
# if get_raw is set then don't fetch as votable
961939
if get_raw:
962940
return ""
963-
votable_def = "votable {" + votable_fields + "}"
964-
votable_open = "votable open"
965-
return "\n".join([votable_def, votable_open])
941+
return "votable {" + ','.join(self.get_votable_fields()) + "}\nvotable open"
966942

967943
def _get_query_footer(self, get_raw=False):
968-
if get_raw:
969-
return ""
970-
votable_close = "votable close"
971-
972-
return votable_close
944+
return "" if get_raw else "votable close"
973945

974946
@validate_epoch_decorator
975947
@validate_equinox_decorator
@@ -1004,25 +976,19 @@ def _args_to_payload(self, *args, **kwargs):
1004976
kwargs['equi'] = kwargs['equinox']
1005977
del kwargs['equinox']
1006978
# remove default None from kwargs
1007-
# be compatible with python3
1008-
for key in list(kwargs):
1009-
if not kwargs[key]:
1010-
del kwargs[key]
979+
kwargs = {key: value for key, value in kwargs.items() if value is not None}
1011980
# join in the order specified otherwise results in error
1012981
all_keys = ['radius', 'frame', 'equi', 'epoch']
1013982
present_keys = [key for key in all_keys if key in kwargs]
1014983
if caller == 'query_criteria_async':
1015-
for k in kwargs:
1016-
present_keys.append(k)
984+
present_keys.extend(kwargs)
1017985
# need ampersands to join args
1018986
args_str = '&'.join([str(val) for val in args])
1019-
if len(args) > 0 and len(present_keys) > 0:
987+
if args and present_keys:
1020988
args_str += " & "
1021989
else:
1022990
args_str = ' '.join([str(val) for val in args])
1023-
kwargs_str = ' '.join("{key}={value}".format(key=key,
1024-
value=kwargs[key])
1025-
for key in present_keys)
991+
kwargs_str = ' '.join(f"{key}={kwargs[key]}" for key in present_keys)
1026992

1027993
# For the record, I feel dirty for writing this wildcard-case hack.
1028994
# This entire function should be refactored when someone has time.
@@ -1081,17 +1047,8 @@ def _parse_coordinates(coordinates):
10811047
raise ValueError("Coordinates not specified correctly")
10821048

10831049

1084-
def _has_length(x):
1085-
# some objects have '__len__' attributes but have no len()
1086-
try:
1087-
len(x)
1088-
return True
1089-
except (TypeError, AttributeError):
1090-
return False
1091-
1092-
10931050
def _get_frame_coords(coordinates):
1094-
if _has_length(coordinates):
1051+
if isiterable(coordinates):
10951052
# deal with vectors differently
10961053
parsed = [_get_frame_coords(cc) for cc in coordinates]
10971054
return ([ra for ra, dec, frame in parsed],

astroquery/simbad/tests/test_simbad.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def test_query_catalog(patch_post):
243243
(FK4_COORDS, '5d0m0s', 2000.0, 'J2000'),
244244
(FK5_COORDS, None, 2000.0, 'J2000'),
245245
(multicoords, 0.5*u.arcsec, 2000.0, 'J2000'),
246+
(multicoords, "0.5s", 2000.0, 'J2000'),
246247
])
247248
def test_query_region_async(patch_post, coordinates, radius, equinox, epoch):
248249
response1 = simbad.core.Simbad.query_region_async(
@@ -283,6 +284,11 @@ def test_query_region_radius_error(patch_post, coordinates, radius,
283284
coordinates, radius=radius, equinox=equinox, epoch=epoch)
284285

285286

287+
def test_query_region_coord_radius_mismatch():
288+
with pytest.raises(ValueError, match="^Mismatch between radii and coordinates$"):
289+
simbad.SimbadClass().query_region(multicoords, radius=[1, 2, 3] * u.deg)
290+
291+
286292
@pytest.mark.parametrize(('coordinates', 'radius', 'equinox', 'epoch'),
287293
[(ICRS_COORDS, "0d", 2000.0, 'J2000'),
288294
(GALACTIC_COORDS, 1.0 * u.marcsec, 2000.0, 'J2000')

astroquery/simbad/tests/test_simbad_remote.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,11 @@ def test_query_region_async(self, temp_dir):
107107

108108
assert response is not None
109109

110-
def test_query_region_async_vector(self, temp_dir):
110+
@pytest.mark.parametrize("radius", (0.5 * u.arcsec, "0.5s"))
111+
def test_query_region_async_vector(self, temp_dir, radius):
111112
simbad = Simbad()
112113
simbad.cache_location = temp_dir
113-
response1 = simbad.query_region_async(multicoords, radius=0.5*u.arcsec)
114+
response1 = simbad.query_region_async(multicoords, radius=radius)
114115
assert response1.request.body == 'script=votable+%7Bmain_id%2Ccoordinates%7D%0Avotable+open%0Aquery+coo+5%3A35%3A17.3+-80%3A52%3A00+radius%3D0.5s+frame%3DICRS+equi%3D2000.0%0Aquery+coo+17%3A47%3A20.4+-28%3A23%3A07.008+radius%3D0.5s+frame%3DICRS+equi%3D2000.0%0Avotable+close' # noqa
115116

116117
def test_query_region(self, temp_dir):

0 commit comments

Comments
 (0)