-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmapping_relations.py
More file actions
80 lines (63 loc) · 2.79 KB
/
mapping_relations.py
File metadata and controls
80 lines (63 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import dictionaries
import argparse
from pathlib import Path
import lxml.etree as ET
import os
def get_key_by_list_element(d, target_element):
    """Return the first key of *d* whose value contains *target_element*.

    Entries are checked in the dictionary's iteration order. When no value
    contains the element, ``None`` is returned.
    """
    matching_keys = (
        key for key, members in d.items() if target_element in members
    )
    return next(matching_keys, None)
def iterate_files(input_path, output_path, rst_type):
    """Translate every ``*.rs3`` file in *input_path* and write the result.

    Each file is parsed and relabelled by ``parseXML`` and then written,
    pretty-printed, to *output_path* under the same file name.

    Args:
        input_path: ``pathlib.Path`` of the directory holding the ``.rs3``
            files (only direct children are matched).
        output_path: ``pathlib.Path`` of the directory to write into. It is
            created (including parents) once the first input file has been
            processed successfully.
        rst_type: Target relation scheme, forwarded unchanged to ``parseXML``.
    """
    for file_path in input_path.glob("*.rs3"):
        # Pass the raw bytes to the parser instead of locale-decoded text:
        # no file handle is left open, the result does not depend on the
        # platform's default encoding, and lxml can honour an XML encoding
        # declaration (it rejects such declarations in ``str`` input).
        tree = parseXML(file_path.read_bytes(), rst_type)

        # Wrap the root element in an ElementTree so it can be serialized.
        element_tree = ET.ElementTree(tree)
        output_path.mkdir(parents=True, exist_ok=True)
        element_tree.write(output_path / file_path.name, pretty_print=True)
def parseXML(xml, rst_type):
    """Parse an rs3 document and rename its relations to the target scheme.

    Relation names are looked up (lower-cased) in ``dictionaries.unsc2gum``
    or ``dictionaries.unsc2rstdt`` via ``get_key_by_list_element`` and
    replaced by the matching key. Two places are rewritten:

    * the ``name`` attribute of ``/rst/header/relations/rel`` elements;
    * the ``relname`` attribute of ``/rst/body/segment`` and
      ``/rst/body/group`` elements.

    Args:
        xml: The document content, as accepted by ``lxml.etree.fromstring``.
        rst_type: ``"GUM"`` or ``"RST-DT"``. Any other value leaves the
            document unmodified.

    Returns:
        The root element of the parsed (and relabelled) document.

    Raises:
        ValueError: If a header relation has no entry in the mapping table.
        KeyError: If a header ``rel`` element has no ``name`` attribute.
    """
    tree = ET.fromstring(xml)

    if rst_type == "GUM":
        mapping = dictionaries.unsc2gum
    elif rst_type == "RST-DT":
        mapping = dictionaries.unsc2rstdt
    else:
        # No mapping table for this scheme: nothing to rename.
        return tree

    elements = tree.xpath(
        "/rst/header/relations/rel | /rst/body/segment | /rst/body/group"
    )
    for element in elements:
        is_header = element.tag == "rel"
        attribute = "name" if is_header else "relname"

        if is_header:
            element_to_find = element.attrib["name"].lower()
        else:
            original_name = element.attrib.get("relname")
            if original_name is None:
                # Body nodes without a relation (e.g. the root) are skipped.
                continue
            element_to_find = original_name.lower()

        key = get_key_by_list_element(mapping, element_to_find)

        # Progress output is kept as it was: every header relation is
        # reported, body nodes only for RST-DT.
        if is_header or rst_type == "RST-DT":
            print("UNSC: ", element_to_find, f" {rst_type}: ", key)

        if key is None:
            if is_header:
                # An unmapped header relation is a hard error; report it
                # explicitly instead of relying on an assert or on lxml
                # rejecting a ``None`` attribute value.
                raise ValueError(
                    f"no {rst_type} mapping for relation {element_to_find!r}"
                )
            # Unmapped body relations are deliberately left unchanged.
            continue

        element.attrib[attribute] = key

    return tree
def main():
    """Command-line entry point.

    Reads ``--rst_type`` (``GUM`` or ``RST-DT``) and ``--path`` (the input
    directory), translates the files found there with ``iterate_files`` and
    prints the directory the results were saved in.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--rst_type", choices=["GUM", "RST-DT"], required=True)
    cli.add_argument("--path", required=True)  # input directory
    options = cli.parse_args()

    target_scheme = options.rst_type
    source_dir = Path(options.path)
    # The output location is fixed relative to the current working directory.
    destination = Path(
        f"./dataset/gold_translated_dataset/07_rst_{target_scheme}-relations"
    )

    iterate_files(source_dir, destination, target_scheme)
    print("Saved in: ", destination)


if __name__ == "__main__":
    main()