import os, sys

import datasets
import transformers

import disrpt_io
import utils


LANGUAGES = []


def read_dataset( input_path, output_path, config, add_lang_token=True, add_frame_token=True, lang_token="", frame_token="" ):
    '''
    - Read the file in input_path
    - Return a Dataset corresponding to the file

    Parameters
    ----------
    input_path : str
        Path to the dataset (a file path, or the data itself as a raw string)
    output_path : str
        Path to an output directory that can be used to write new split files
    config : dict
        Configuration; in particular "model_checkpoint" is used to load the tokenizer
    add_lang_token : bool
        If True, add a special language token at the beginning of each sequence
    add_frame_token : bool
        If True, add a special framework token at the beginning of each sequence
    lang_token : str
        Language value used when the input is a raw string (no file name to parse)
    frame_token : str
        Framework value used when the input is a raw string

    Returns
    -------
    Dataset
        Dataset built from the train / dev / test file given in input_path
    Tokenizer
        The tokenizer used for the dataset
    '''
    model_checkpoint = config["model_checkpoint"]

    tokenizer = transformers.AutoTokenizer.from_pretrained( model_checkpoint )

    dataset = DatasetSeq( input_path, output_path, config, tokenizer,
                          add_lang_token=add_lang_token, add_frame_token=add_frame_token,
                          lang_token=lang_token, frame_token=frame_token )
    dataset.read_and_tokenize()

    LABEL_NAMES_BIO = retrieve_bio_labels( dataset )
    dataset.set_label_names_bio(LABEL_NAMES_BIO)
    return dataset, tokenizer
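
# Example usage (a sketch, not part of the original pipeline): reading one DISRPT
# split with the sample data and config used as defaults in the __main__ block below.
# The output directory "outputs/" is an assumption; it is only used to write .split
# files in 'tok' mode.
#
#   config = utils.read_config("./config_seg.json")
#   dataset, tokenizer = read_dataset(
#       "data_test/eng.sample.rstdt/eng.sample.rstdt_train.conllu",
#       "outputs/",
#       config,
#       add_lang_token=True, add_frame_token=True)
#   print(dataset.tokenized_datasets[0])   # input_ids / attention_mask / labels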


class DatasetDisc( ):

    def __init__(self, annotations_file, output_path, config, tokenizer, dset=None ):
        """
        Here we save the location of our input file,
        load the data, i.e. retrieve the list of texts and associated labels,
        build the vocabulary if none is given,
        and define the pipelines used to prepare the data.
        """
        self.annotations_file = annotations_file
        if isinstance(annotations_file, str) and not os.path.isfile(annotations_file):
            # Raw data passed directly as a string rather than a file path
            print("this is a string dataset")
            self.basename = "input"
            self.dset = dset
            self.corpus_name = None
        else:
            self.basename = os.path.basename( self.annotations_file )
            self.dset = self.basename.split(".")[2].split('_')[1]
            self.corpus_name = self.basename.split('_')[0]

        self.tokenizer = tokenizer
        self.config = config

        self.output_path = output_path

        self.mode = config["type"]
        self.task = config["task"]
        self.trace = config["trace"]
        self.tok_config = config["tok_config"]
        # .get() avoids a KeyError when the config has no sentence splitter
        # (DatasetSeq re-reads this key defensively in its own __init__)
        self.sent_spliter = config.get("sent_spliter")

        self.id2label, self.label2id = {}, {}

        self.corpus = init_corpus( self.task )

    def read_and_tokenize( self ):
        print("\n-- READ FROM FILE:", self.annotations_file )
        try:
            self.read_annotations( )
        except Exception as err:
            print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
            raise

        self.set_labels( )
        print( "self.label2id", self.label2id )

        self.tokenize_dataset()
        if self.trace:
            if self.dset:
                print( "\n-- FINISHED READING", self.dset, "PRINTING TRACE --")
            self.print_trace()

    def tokenize_dataset( self ):
        # Named to match the call in read_and_tokenize(); implemented in subclasses
        raise NotImplementedError

    def set_labels( self ):
        raise NotImplementedError

    def read_annotations( self ):
        '''
        Generate a Corpus object based on the input file.
        Since .tok files are not segmented into sentences, a sentence splitter
        is used (configured via the "sent_spliter" entry of the config, e.g.
        ersatz or wtpsplit's SaT).
        '''
        if os.path.isfile(self.annotations_file):
            self.corpus.from_file(self.annotations_file)
            lang = os.path.basename(self.annotations_file).split(".")[0]
            frame = os.path.basename(self.annotations_file).split(".")[1]
            base = os.path.basename(self.annotations_file)
        else:
            src = self.mode if self.mode in ["tok", "conllu", "split"] else "conllu"
            self.corpus.from_string(self.annotations_file, src=src)
            lang = self.lang_token
            frame = self.frame_token
            base = "input.text"

        # Keep the language / framework on the instance (used e.g. by print_stats)
        self.lang = lang
        self.frame = frame

        for doc in self.corpus.docs:
            doc.lang = lang
            doc.frame = frame

        if self.mode == 'tok':
            from wtpsplit import SaT
            sat_version = "sat-3l"
            if "sat_model" in self.config:
                sat_version = self.config["sat_model"]

            sat_model = SaT(sat_version)
            self.corpus.sentence_split(model=self.sent_spliter, lang="default-multilingual", sat_model=sat_model)

            parts = base.split(".")[:-1]
            split_filename = ".".join(parts) + ".split"
            split_file = os.path.join(self.output_path, split_filename)
            self.corpus.format(file=split_file)

    def print_trace( self ):
        print( "\n| Annotation_file: ", self.annotations_file )
        print( '| Output_path:', self.output_path )
        print( '| Nb_of_instances:', len(self.dataset), "(", len(self.dataset['labels']), ")" )

    def print_stats( self ):
        print( "| Annotation_file: ", self.annotations_file )
        if self.dset: print( "| Data_split: ", self.dset )
        print( "| Task: ", self.task )
        print( "| Lang: ", self.lang )
        print( "| Mode: ", self.mode )
        print( "| Label_names: ", self.LABEL_NAMES)
        print( "| Number_of_instances: ", len(self.dataset) )


class DatasetSeq(DatasetDisc):

    def __init__( self, annotations_file, output_path, config, tokenizer, add_lang_token=True, add_frame_token=True,
                  dset=None, lang_token="", frame_token="" ):
        """
        Class for tasks corresponding to a sequence labeling problem
        (seg, conn).
        Here we save the location of our input file,
        load the data, i.e. retrieve the list of texts and associated labels,
        build the vocabulary if none is given,
        and define the pipelines used to prepare the data.
        """
        DatasetDisc.__init__( self, annotations_file, output_path, config,
                              tokenizer, dset=dset )
        self.add_lang_token = add_lang_token
        self.add_frame_token = add_frame_token
        self.lang_token = lang_token
        self.frame_token = frame_token

        if self.mode == 'tok' and self.output_path is None:
            # Keep output_path as a directory: read_annotations() builds the
            # ".split" filename itself before writing the split file there.
            self.output_path = os.path.dirname( self.annotations_file )

        self.sent_spliter = None
        if "sent_spliter" in self.config:
            self.sent_spliter = self.config["sent_spliter"]

        self.LABEL_NAMES_BIO = None

    def tokenize_dataset( self ):
        if self.trace:
            print(f"\n-- Creating dataset from generator (add_lang_token={self.add_lang_token})")
        self.dataset = datasets.Dataset.from_generator(
            gen,
            gen_kwargs={"corpus": self.corpus, "label2id": self.label2id, "mode": self.mode,
                        "add_lang_token": self.add_lang_token, "add_frame_token": self.add_frame_token},
        )
        if self.trace:
            print( self.dataset[0])

        self.all_word_ids = []

        if self.trace:
            print( "\n-- Mapping dataset labels and subwords ")
        self.tokenized_datasets = self.dataset.map(
            tokenize_and_align_labels,
            fn_kwargs = {"tokenizer": self.tokenizer,
                         "id2label": self.id2label,
                         "label2id": self.label2id,
                         "all_word_ids": self.all_word_ids,
                         "config": self.config},
            batched=True,
            remove_columns=self.dataset.column_names,
        )
        if self.trace:
            print( self.tokenized_datasets[0])

    def set_labels(self):
        self.LABEL_NAMES = self.corpus.LABELS
        self.id2label = {i: label for i, label in enumerate( self.LABEL_NAMES )}
        self.label2id = {v: k for k, v in self.id2label.items()}

    def set_label_names_bio( self, LABEL_NAMES_BIO ):
        self.LABEL_NAMES_BIO = LABEL_NAMES_BIO

    def print_trace( self ):
        super().print_trace()
        print( '\n--First sentence: original tokens and labels.\n')
        print( self.dataset[0]['tokens'] )
        print( self.dataset[0]['labels'] )
        print( "\n---First sentence: tokenized version:\n")
        print( self.tokenized_datasets[0] )


def init_corpus( task ):
    # Normalise once so 'conn' and 'seg' are matched consistently
    task = task.strip().lower()
    if task == 'conn':
        return disrpt_io.ConnectiveCorpus()
    elif task == 'seg':
        return disrpt_io.SegmentCorpus()
    else:
        raise NotImplementedError


def gen( corpus, label2id, mode, add_lang_token=True, add_frame_token=True ):
    source = "split"
    if mode == 'conllu':
        source = "conllu"
    for doc in corpus.docs:
        lang = getattr(doc, 'lang', 'xx')
        lang_token = f"[LANG={lang}]"

        frame = getattr(doc, 'frame', 'xx')
        frame_token = f"[FRAME={frame}]"
        sent_list = doc.sentences[source] if source in doc.sentences else doc.sentences
        for sentence in sent_list:
            labels = []
            tokens = []
            if add_lang_token:
                tokens.append(lang_token)
                labels.append(-100)
            if add_frame_token:
                tokens.append(frame_token)
                labels.append(-100)

            for t in sentence.toks:
                tokens.append(t.form)
                if t.label == '_':
                    # '_' means "no annotation": map it to 'O' when available
                    if 'O' in label2id:
                        labels.append(label2id['O'])
                    else:
                        labels.append(list(label2id.values())[0])
                else:
                    labels.append(label2id[t.label])
            yield {
                "tokens": tokens,
                "labels": labels
            }
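
# A sketch of what gen() yields for one sentence, assuming both special tokens are
# enabled and a connective-task label set such as {'O', 'Seg=B-Conn', 'Seg=I-Conn'}
# (token forms and labels below are illustrative, not from a real corpus):
#
#   {"tokens": ["[LANG=eng]", "[FRAME=rstdt]", "However", ",", "it", "works"],
#    "labels": [-100, -100, label2id["Seg=B-Conn"], label2id["O"],
#               label2id["O"], label2id["O"]]}
#
# The two leading -100 labels make the loss ignore the [LANG=...] / [FRAME=...] tokens.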


def get_tokenizer( model_checkpoint ):
    return transformers.AutoTokenizer.from_pretrained(model_checkpoint)


def tokenize_and_align_labels( dataset, tokenizer, id2label, label2id, all_word_ids, config ):
    '''
    (Done in batches)
    To preprocess our whole dataset, we need to tokenize all the inputs and
    apply align_labels_with_tokens() on all the labels.
    (With HF datasets, we use Dataset.map to process batches.)
    The word_ids() function needs the index of the example we want the word IDs
    of when the inputs to the tokenizer are lists of texts (or, in our case,
    lists of lists of words), so we pass that index too.
    Tokenization options come from the "tok_config" entry of the config.
    '''
    tokenized_inputs = tokenizer(
        dataset["tokens"],
        truncation=config["tok_config"]['truncation'],
        padding=config["tok_config"]['padding'],
        max_length=config["tok_config"]['max_length'],
        is_split_into_words=True
    )

    all_labels = dataset["labels"]
    new_labels = []

    for i, labels in enumerate(all_labels):
        word_ids = tokenized_inputs.word_ids(i)
        new_labels.append(align_labels_with_tokens(labels, word_ids, id2label, label2id, tokenizer, tokenized_inputs ))
        # Collect the word ids of each example (stored in self.all_word_ids by the caller)
        all_word_ids.append( word_ids )

    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs
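
# A sketch (not from the original module) of the word_ids() behaviour this relies on,
# for a fast, BERT-like tokenizer; the exact subword splits depend on the checkpoint:
#
#   enc = tokenizer(["[LANG=eng]", "Nevertheless", ",", "fine"], is_split_into_words=True)
#   enc.word_ids(0)
#   # e.g. [None, 0, 0, 0, 1, 1, 2, 3, None]  -> None for special tokens ([CLS]/[SEP]),
#   #      repeated indices for subwords of the same input word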


def align_labels_with_tokens(labels, word_ids, id2label, label2id, tokenizer, tokenized_inputs):
    '''
    BERT-like tokenization creates new (sub)tokens, so we need to align the labels.
    Special tokens get a label of -100, because by default -100 is an index that is
    ignored in the loss function we will use (cross entropy).
    The first subword of each word gets that word's label. [Adapted from the HF
    course on NER; unlike the course, subwords inside a word whose label is a
    B- label are masked with -100 here rather than converted to I-.]
    '''
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word (or a special token)
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Continuation subword of the current word
            label = labels[word_id]
            # Only the first subword keeps a B- label; later subwords are masked
            if label != -100 and 'B-' in id2label[label]:
                label = -100
            new_labels.append(label)
    return new_labels
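
# A small worked example (illustrative values): suppose id2label = {0: 'O', 1: 'Seg=B-Conn'},
# the sentence has two words with labels = [0, 1], and the tokenizer produced
# word_ids = [None, 0, 1, 1, None] ([CLS], word 0, word 1 split into two subwords, [SEP]).
# align_labels_with_tokens() then returns:
#
#   [-100, 0, 1, -100, -100]
#
# i.e. the second subword of the B- labelled word is masked with -100, as are the
# special tokens.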


def retrieve_bio_labels( dataset ):
    '''
    Used to build BIO-style label names for compute_metrics, which relies on
    classic metrics for the BIO scheme; we therefore create a mapping from the
    original labels to BIO labels, e.g.:
    '_'          --> 'O'
    'Seg=B-Conn' --> 'B'
    'Seg=I-Conn' --> 'I'
    Should also work for segmentation (TODO: check).

    dataset: DatasetSeq instance (train / dev / test split)

    Return: list: label names mapped to BIO, index-aligned with dataset.LABEL_NAMES
    '''
    task = dataset.task
    LABEL_NAMES_BIO = []
    LABEL_NAMES = dataset.LABEL_NAMES
    label2idx, idx2newl = {}, {}
    if task in ["conn", "seg"]:
        for i, l in enumerate( LABEL_NAMES ):
            label2idx[l] = i
        for l in label2idx:
            nl = ''
            if 'B' in l:
                nl = 'B'
            elif 'I' in l:
                nl = 'I'
            else:
                nl = 'O'
            idx2newl[label2idx[l]] = nl
        for i in sorted(idx2newl):
            LABEL_NAMES_BIO.append(idx2newl[i])

    return LABEL_NAMES_BIO
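
# For example (values taken from the docstring above): with
# dataset.LABEL_NAMES == ['_', 'Seg=B-Conn', 'Seg=I-Conn'] and task 'conn',
# retrieve_bio_labels(dataset) returns ['O', 'B', 'I'], index-aligned with LABEL_NAMES.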


def set_language( lang ):
    # Map DISRPT language codes to those known by the sentence splitter;
    # unknown languages fall back to the multilingual default.
    if lang == "sp": lang = "es"
    if lang not in LANGUAGES:
        lang = "default-multilingual"
    return lang
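
# For instance, set_language("sp") returns "es" if "es" is listed in LANGUAGES,
# otherwise "default-multilingual" (always the case while LANGUAGES is empty).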


if __name__=="__main__":

    import argparse

    parser = argparse.ArgumentParser(
        description='DISCUT: reading data from disrpt_io and converting to HuggingFace'
    )

    parser.add_argument("-t", "--train",
                        help="Training file. Default: data_test/eng.sample.rstdt/eng.sample.rstdt_train.conllu",
                        default="data_test/eng.sample.rstdt/eng.sample.rstdt_train.conllu")

    parser.add_argument("-d", "--dev",
                        help="Dev file. Default: data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu",
                        default="data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu")

    parser.add_argument("-o", "--output",
                        help="Directory where models and predictions will be saved. Default: '' (current directory).",
                        default="")

    parser.add_argument("-c", "--config",
                        help="Config file. Default: ./config_seg.json",
                        default="./config_seg.json")

    parser.add_argument( '-v', '--trace',
                         action='store_true',
                         default=False,
                         help="Whether to print full messages. If used, it will override the value in the config file.")

    args = parser.parse_args()

    train_path = args.train
    dev_path = args.dev
    print(dev_path)
    if not os.path.isfile(dev_path):
        print( "ERROR with dev file:", dev_path)
    output_path = args.output
    config_file = args.config

    trace = args.trace

    print( '\n-[JEDIS]--PROGRAM (reader) ARGUMENTS')
    print( '| Train_path', train_path )
    print( '| Dev_path', dev_path )
    print( "| Output_path", output_path )
    print( '| Config', config_file )

    print( '\n-[JEDIS]--CONFIG INFO')
    config = utils.read_config( config_file )
    utils.print_config(config)

    # -v overrides the config value only when the flag is actually used
    if trace:
        config['trace'] = True

    print( "\n-[JEDIS]--READING DATASETS" )

    # Pass the parsed output directory (read_dataset's second argument is an output
    # path, not a second data file) and avoid shadowing the datasets module
    dataset, tokenizer = read_dataset( train_path, output_path, config, add_lang_token=True )