9 changes: 5 additions & 4 deletions probing/data_former.py
@@ -24,7 +24,7 @@ def __init__(
self.shuffle = shuffle
self.data_path = get_probe_task_path(probe_task, data_path)

self.samples, self.unique_labels = self.form_data(sep=sep)
self.samples, self.unique_labels, self.num_words = self.form_data(sep=sep)

def __len__(self):
return len(self.samples)
@@ -48,8 +48,9 @@ def form_data(
samples_dict = defaultdict(list)
unique_labels = set()
dataset = pd.read_csv(self.data_path, sep=sep, header=None, dtype=str)
for _, (stage, label, text) in dataset.iterrows():
samples_dict[stage].append((text, label))
for _, (stage, label, text, word_indices) in dataset.iterrows():
num_words = len(word_indices)
samples_dict[stage].append((text, label, word_indices))
unique_labels.add(label)

if self.shuffle:
@@ -58,7 +59,7 @@
}
else:
samples_dict = {k: np.array(v) for k, v in samples_dict.items()}
return samples_dict, unique_labels
return samples_dict, unique_labels, num_words


class EncodedVectorFormer(Dataset):
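As a hedged illustration of the four-column task file that form_data now consumes (the column order stage, label, text, word_indices is inferred from the reading loop above and from the updated writers below; the file name and values are invented):

import pandas as pd

# Hypothetical task file: stage \t label \t text \t word_indices (comma-joined token ids)
rows = [
    ["tr", "Sing", "A cat sleeps .", "1"],
    ["va", "Plur", "The cats sleep .", "1"],
]
pd.DataFrame(rows).to_csv("sample_task.tsv", sep="\t", header=False, index=False)

# Mirrors the updated parsing loop in DataFormer.form_data
dataset = pd.read_csv("sample_task.tsv", sep="\t", header=None, dtype=str)
for _, (stage, label, text, word_indices) in dataset.iterrows():
    print(stage, label, text, word_indices)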
11 changes: 6 additions & 5 deletions probing/ud_filter/filtering_probing.py
@@ -41,7 +41,7 @@ def __init__(self, shuffle: bool = True):
self.classes: Dict[
str, Tuple[Dict[str, Dict[str, Any]], Dict[Tuple[str, str], Dict[str, Any]]]
] = {}
self.probing_dict: Dict[str, List[str]] = {}
self.probing_dict: Dict[str, List[Tuple[str, List[int]]]] = {}
self.parts_data: Dict[str, List[List[str]]] = {}

def upload_files(
@@ -91,10 +91,11 @@ def _filter_conllu(self, class_label: str) -> Tuple[List[str], List[str]]:
for sentence in self.sentences:
sf = SentenceFilter(sentence)
tokenized_sentence = " ".join(wordpunct_tokenize(sentence.metadata["text"]))
if sf.filter_sentence(node_pattern, constraints):
matching.append(tokenized_sentence)
filter_result = sf.filter_sentence(node_pattern, constraints)
if filter_result is not False:
matching.append((tokenized_sentence, filter_result))
else:
not_matching.append(tokenized_sentence)
not_matching.append((tokenized_sentence, ()))
return matching, not_matching

def filter_and_convert(
@@ -128,7 +129,7 @@ def filter_and_convert(
matching, not_matching = self._filter_conllu(label)
self.probing_dict[label] = matching
if len(self.classes) == 1:
self.probing_dict["not_" + list(self.classes.keys())[0]] = not_matching
self.probing_dict["not_" + label] = not_matching
self.probing_dict = delete_duplicates(self.probing_dict)

self.parts_data = subsamples_split(
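For orientation, a hedged example of the probing_dict shape after this change (the class name and sentences are invented): each label now maps to (tokenized_sentence, matched token indices) pairs, with an empty tuple for the complementary "not_" class.

# Invented contents after filter_and_convert for a single queried class
probing_dict = {
    "Number": [("The cats sleep .", (1,))],
    "not_Number": [("Sleep well .", ())],
}
for label, samples in probing_dict.items():
    for sentence, token_ids in samples:
        print(label, sentence, token_ids)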
4 changes: 2 additions & 2 deletions probing/ud_filter/sentence_filter.py
@@ -189,7 +189,7 @@ def find_isomorphism(self) -> bool:
k: {edges[i]} for i, k in enumerate(self.possible_token_pairs)
}
self.nodes_tokens = {
np[i]: [list(self.possible_token_pairs[np])[0][i]]
np[i]: list(self.possible_token_pairs[np])[0][i]
Collaborator:
Why did you delete the brackets? The original type is self.nodes_tokens: Dict[str, List[int]]; left like this, everything breaks.

for np in self.possible_token_pairs
for i in range(2)
}
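An invented, minimal example of the structures involved in the line discussed above (the shapes are assumptions based on the surrounding code, not taken from the repository's tests):

possible_token_pairs = {("V", "N"): {(3, 1)}}  # one matched token-id pair per node-pair key

# With the brackets removed, as in this PR, the values become bare token ids
nodes_tokens = {
    np[i]: list(possible_token_pairs[np])[0][i]
    for np in possible_token_pairs
    for i in range(2)
}
print(nodes_tokens)  # {'V': 3, 'N': 1}; the original annotation expects {'V': [3], 'N': [1]}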
@@ -243,6 +243,6 @@ def filter_sentence(
else:
self.sent_deprels = self.all_deprels()
if self.match_constraints():
return True
return tuple(self.nodes_tokens.values())
Collaborator:
The function returned bool, but now you return a tuple. That breaks the typings not only here but also in the code that depended on this function.

else:
return False
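If the tuple return is kept, a hedged sketch of how the widened contract could be written down for callers (the alias and helper below are hypothetical and not part of this PR):

from typing import Tuple, Union

FilterResult = Union[Tuple[int, ...], bool]  # assumed widened return type: token ids on match, False otherwise

def unpack_filter_result(sentence: str, result: FilterResult) -> Tuple[str, Tuple[int, ...]]:
    # Normalizes the two outcomes so downstream code always sees (sentence, token_ids)
    if result is False:
        return sentence, ()
    return sentence, tuple(result)

print(unpack_filter_result("The cats sleep .", (1, 2)))  # ('The cats sleep .', (1, 2))
print(unpack_filter_result("Sleep well .", False))       # ('Sleep well .', ())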
26 changes: 14 additions & 12 deletions probing/ud_filter/utils.py
@@ -49,7 +49,7 @@ def subsamples_split(
if not probing_data:
raise Exception("All classes have less sentences than the number of classes")
parts = {}
data, labels = map(np.array, zip(*probing_data))
data, labels = map(list, zip(*probing_data))
X_train, X_test, y_train, y_test = train_test_split(
data,
labels,
@@ -58,19 +58,20 @@
shuffle=shuffle,
random_state=random_seed,
)

if len(partition) == 2:
parts = {split[0]: [X_train, y_train], split[1]: [X_test, y_test]}
else:
filtered_labels = filter_labels_after_split(y_test)
if len(filtered_labels) >= 2:
X_train = X_train[np.isin(y_train, filtered_labels)]
y_train = y_train[np.isin(y_train, filtered_labels)]
X_test = X_test[np.isin(y_test, filtered_labels)]
y_test = y_test[np.isin(y_test, filtered_labels)]
train_mask = np.isin(y_train, filtered_labels)
X_train = [X_train[i] for i in range(len(train_mask)) if train_mask[i]]
y_train = [y_train[i] for i in range(len(train_mask)) if train_mask[i]]
test_mask = np.isin(y_test, filtered_labels)
X_test = [X_test[i] for i in range(len(test_mask)) if test_mask[i]]
y_test = [y_test[i] for i in range(len(test_mask)) if test_mask[i]]

val_size = partition[1] / (1 - partition[0])
if y_test.size != 0:
if len(y_test) != 0:
X_val, X_test, y_val, y_test = train_test_split(
X_test,
y_test,
@@ -118,8 +119,9 @@ def writer(
with open(result_path, "w", encoding="utf-8") as newf:
my_writer = csv.writer(newf, delimiter="\t", lineterminator="\n")
for part in partition_sets:
for sentence, value in zip(*partition_sets[part]):
my_writer.writerow([part, value, sentence])
for sentence_and_ids, value in zip(*partition_sets[part]):
sentence, ids = sentence_and_ids
my_writer.writerow([part, value, sentence, ",".join([str(x) for x in ids])])
return result_path
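A hedged sketch of the rows this writer now emits, with the matched token indices joined by commas into a fourth column (the partition data below is invented):

import csv

partition_sets = {
    "tr": [
        [("The cats sleep .", (1, 2)), ("A cat sleeps .", (1, 2))],  # (sentence, matched token ids)
        ["Plur", "Sing"],                                            # labels
    ],
}
with open("example_task.tsv", "w", encoding="utf-8") as newf:
    my_writer = csv.writer(newf, delimiter="\t", lineterminator="\n")
    for part in partition_sets:
        for sentence_and_ids, value in zip(*partition_sets[part]):
            sentence, ids = sentence_and_ids
            my_writer.writerow([part, value, sentence, ",".join(str(x) for x in ids)])
# each row: tr \t Plur \t The cats sleep . \t 1,2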


@@ -150,11 +152,11 @@ def determine_ud_savepath(
def delete_duplicates(probing_dict: Dict[str, List[str]]) -> Dict[str, List[str]]:
"""Deletes sentences with more than one different classes of node_pattern found"""

all_sent = [s for cl_sent in probing_dict.values() for s in cl_sent]
duplicates = [item for item, count in Counter(all_sent).items() if count > 1]
all_sent = [sent for cl_sent in probing_dict.values() for sent, inds in cl_sent]
duplicates = {item for item, count in Counter(all_sent).items() if count > 1}
new_probing_dict = {}
for cl in probing_dict:
new_probing_dict[cl] = [s for s in probing_dict[cl] if s not in duplicates]
new_probing_dict[cl] = [(sent, ind) for sent, ind in probing_dict[cl] if sent not in duplicates]
return new_probing_dict


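Because the samples are now (sentence, indices) tuples rather than a homogeneous array, subsamples_split above keeps np.isin for the boolean mask but applies it with list comprehensions. A self-contained sketch with invented labels:

import numpy as np

X_test = [("The cats sleep .", [1]), ("Sleep well .", []), ("A cat sleeps .", [1])]
y_test = ["Plur", "Imp", "Sing"]
filtered_labels = ["Plur", "Sing"]  # labels that survived filter_labels_after_split

test_mask = np.isin(y_test, filtered_labels)
X_test = [X_test[i] for i in range(len(test_mask)) if test_mask[i]]
y_test = [y_test[i] for i in range(len(test_mask)) if test_mask[i]]
print(X_test)  # the "Imp" sample is dropped
print(y_test)  # ['Plur', 'Sing']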
11 changes: 7 additions & 4 deletions probing/ud_parser/ud_parser.py
@@ -65,8 +65,9 @@ def writer(
with open(result_path, "w", encoding="utf-8") as newf:
my_writer = csv.writer(newf, delimiter="\t", lineterminator="\n")
for part in partition_sets:
for sentence, value in zip(*partition_sets[part]):
my_writer.writerow([part, value, sentence])
for sentence_and_id, value in zip(*partition_sets[part]):
sentence, id = sentence_and_id
my_writer.writerow([part, value, sentence, id])
return result_path

def find_category_token(
@@ -134,15 +135,17 @@ def classify(
)
):
value = category_token["feats"][category]
probing_data[value].append(s_text)
token_id = category_token["id"] - 1
probing_data[value].append((s_text, token_id))
elif self.sorting == "by_pos_and_deprel":
pos, deprel = subcategory.split("_")
if (
category_token["upos"] == pos
and category_token["deprel"] == deprel
):
value = category_token["feats"][category]
probing_data[value].append(s_text)
token_id = category_token["id"] - 1
probing_data[value].append((s_text, token_id))
return probing_data

def filter_labels_after_split(self, labels: List[Any]) -> List[Any]:
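A short sketch of the 0-based token index now stored by classify (the sentence is invented; conllu token ids are 1-based, hence the subtraction; the conllu package import is an assumption based on the token fields used above):

from conllu import parse

data = """# text = The cats sleep .
1\tThe\tthe\tDET\t_\t_\t2\tdet\t_\t_
2\tcats\tcat\tNOUN\t_\tNumber=Plur\t3\tnsubj\t_\t_
3\tsleep\tsleep\tVERB\t_\t_\t0\troot\t_\t_
4\t.\t.\tPUNCT\t_\t_\t3\tpunct\t_\t_
"""
sentence = parse(data)[0]
category_token = sentence[1]               # the "cats" token
token_id = category_token["id"] - 1        # conllu ids are 1-based -> 0-based position in the text
print(token_id, category_token["feats"]["Number"])  # 1 Plur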