import os
import sys
import json
import itertools
from pathlib import Path

import numpy as np

import evaluate
import disrpt_eval_2025


def prepare_compute_metrics(LABEL_NAMES):
    '''
    Return the metrics function to be used in the Trainer evaluation loop.

    For the 'seg' or 'conn' tasks (could be expanded to other sequence /
    classification tasks), scores are computed with seqeval; tokens whose
    label is -100 (padding / special tokens) are ignored.

    Parameters
    ----------
    LABEL_NAMES: dict
        Mapping from label ids to BIO label strings, as expected by seqeval.

    Returns
    -------
    compute_metrics: function
    '''
    def compute_metrics(eval_preds):
        nonlocal LABEL_NAMES

        logits, labels = eval_preds
        predictions = np.argmax(logits, axis=-1)
        metric = evaluate.load("seqeval")

        # drop positions labelled -100 and map ids back to label strings
        true_labels = [[LABEL_NAMES[l] for l in label if l != -100] for label in labels]
        true_predictions = [
            [LABEL_NAMES[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]

        all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
        print_metrics(all_metrics)
        return {
            "precision": all_metrics["overall_precision"],
            "recall": all_metrics["overall_recall"],
            "f1": all_metrics["overall_f1"],
            "accuracy": all_metrics["overall_accuracy"],
        }

    return compute_metrics
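
# Example (hypothetical usage): the label dict and the Trainer wiring below are
# only an illustration of how the closure returned above is typically plugged in.
#
#   label_names = {0: "O", 1: "B-Seg", 2: "I-Seg"}
#   compute_metrics = prepare_compute_metrics(label_names)
#   trainer = Trainer(model=model, args=args, compute_metrics=compute_metrics, ...)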


def print_metrics(all_metrics):
    '''Pretty-print the dictionary returned by seqeval.'''
    for p, v in all_metrics.items():
        if '_' in p:
            # overall_* entries: scalar scores
            print(p, v)
        else:
            # per-label entries: a dict of precision / recall / f1 / number
            print(p + ' = ' + str(v))


def compute_metrics_dirspt(dataset_eval, pred_file, task='seg'):
    '''Score predictions with the official DISRPT evaluation script.'''
    print("\nPerformance computed using the disrpt eval script on", dataset_eval.annotations_file,
          pred_file)
    if task == 'seg':
        my_eval = disrpt_eval_2025.SegmentationEvaluation("temp_test_disrpt_eval_seg",
                                                          dataset_eval.annotations_file,
                                                          pred_file)
    elif task == 'conn':
        my_eval = disrpt_eval_2025.ConnectivesEvaluation("temp_test_disrpt_eval_conn",
                                                         dataset_eval.annotations_file,
                                                         pred_file)
    else:
        raise NotImplementedError("Unknown task: " + task)
    my_eval.compute_scores()
    my_eval.print_results()
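
# Example (hypothetical usage): dataset_eval is assumed to expose the
# .annotations_file attribute used above; the path is illustrative only.
#
#   compute_metrics_dirspt(dataset_eval, "out/eng.rst.gum_dev_pred.conllu", task='seg')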


def clean_pred_file(pred_path: str, out_path: str):
    '''Copy a prediction file, dropping synthetic [LANG=...] / [FRAME=...] tokens.'''
    c = 0
    with open(pred_path, "r", encoding="utf8") as fin, open(out_path, "w", encoding="utf8") as fout:
        for line in fin:
            # keep comments and sentence-separating blank lines as they are
            if line.strip() == "" or line.startswith("#"):
                fout.write(line)
                continue
            fields = line.strip().split("\t")
            token = fields[1]
            if token.startswith("[LANG=") or token.startswith("[FRAME="):
                c += 1
                continue
            fout.write(line)
    print(f"Removed {c} synthetic tokens")
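
# Example (hypothetical paths): removes the [LANG=...] / [FRAME=...] tokens so
# the prediction file can be scored against the gold file.
#
#   clean_pred_file("out/eng.rst.gum_dev_pred.tok", "out/eng.rst.gum_dev_pred.clean.tok")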


def read_config(config_file):
    '''Read the JSON config file for training.'''
    with open(config_file) as f:
        config = json.load(f)
    if 'frozen' in config['trainer_config']:
        config['trainer_config']["frozen"] = update_frozen_set(config['trainer_config']["frozen"])
    return config
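
# Example (minimal sketch of the expected JSON, inferred from the keys used in
# this module; real config files may contain additional fields):
#
#   {
#     "model_name": "xlmr",
#     "task": "seg",
#     "trainer_config": {"batch_size": 16, "learning_rate": 2e-5, "frozen": ["0-3", "5"]}
#   }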


def update_frozen_set(freeze):
    '''
    Expand a list of index specifications (the 'frozen' option) into a set of integers.

    Each item is either a single index ("5") or an inclusive range ("0-3").
    '''
    frozen = set()
    for spec in freeze:
        if "-" in spec:
            b, e = spec.split("-")
            frozen = frozen | set(range(int(b), int(e) + 1))
        else:
            frozen.add(int(spec))
    return frozen
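
# Example (illustrative): ["0-3", "5"] -> {0, 1, 2, 3, 5}
#
#   update_frozen_set(["0-3", "5"])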


def print_config(config):
    '''Print info from config dictionary'''
    print('\n'.join(['| ' + k + ": " + str(v) for (k, v) in config.items()]))


def retrieve_files_dataset(input_path, list_dataset, mode='conllu', dset='train'):
    '''Collect the "*_<dset>" data files for the requested corpora.'''
    # case-insensitive extension patterns for Path.rglob
    if mode == 'conllu':
        pat = ".[cC][oO][nN][lL][lL][uU]"
    elif mode == 'tok':
        pat = ".[tT][oO][kK]"
    else:
        sys.exit('Unknown mode for file extension: ' + mode)
    if len(list_dataset) == 0:
        # no corpus filter: take every matching file under input_path
        return list(Path(input_path).rglob("*_" + dset + pat))
    else:
        # only keep the sub-directories listed in list_dataset
        matched = []
        for subdir in os.listdir(input_path):
            if subdir in list_dataset:
                matched.extend(list(Path(os.path.join(input_path, subdir)).rglob("*_" + dset + pat)))
        return matched
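
# Example (hypothetical layout, assuming one sub-directory per corpus as in the
# DISRPT data releases):
#
#   files = retrieve_files_dataset("data/", ["eng.rst.gum", "fra.sdrt.annodis"],
#                                  mode='conllu', dset='dev')
#   # -> e.g. [PosixPath('data/eng.rst.gum/eng.rst.gum_dev.conllu'), ...]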


def init_wandb(config, model_checkpoint, annotations_file):
    '''
    Initialize a new WANDB run to keep track of the experiments.

    Parameters
    ----------
    config : dict
        Used to retrieve the entity and project names (from the config file)
    model_checkpoint : str
        Name of the PLM used
    annotations_file : str
        Path to the training file

    Returns
    -------
    None
    '''
    print("Initializing a WANDB run")

    import wandb
    proj_wandb = config["wandb"]
    ent_wandb = config["wandb_ent"]

    wandb.init(
        project=proj_wandb,
        entity=ent_wandb,
        config={
            "model_checkpoint": model_checkpoint,
            "dataset": annotations_file,
        }
    )
    wandb.define_metric("epoch")
    wandb.define_metric("f1", step_metric="batch")
    wandb.define_metric("f1", step_metric="epoch")
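
# Example (hypothetical values): the config is assumed to contain the
# "wandb" (project) and "wandb_ent" (entity) keys read above.
#
#   init_wandb({"wandb": "disrpt-seg", "wandb_ent": "my-team"},
#              "xlm-roberta-base", "data/eng.rst.gum/eng.rst.gum_train.conllu")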


def set_name_output_dir(output_dir, config, corpus_name):
    '''
    Set the path of the target directory used to store models. The name contains
    info about the corpus, the PLM, the task and the hyperparameter values.

    Parameters
    ----------
    output_dir : str
        Path to the output directory provided by the user
    config : dict
        Configuration information
    corpus_name : str
        Name of the corpus

    Returns
    -------
    str: Path to the output directory
    '''
    hyperparam = [
        config['trainer_config']['batch_size'],
        # avoid scientific notation in the directory name
        np.format_float_positional(config['trainer_config']['learning_rate'])
    ]
    output_dir = os.path.join(output_dir,
                              '_'.join([
                                  corpus_name,
                                  config["model_name"],
                                  config["task"],
                                  '_'.join([str(p) for p in hyperparam])
                              ]))
    return output_dir
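
# Example (illustrative values): with batch_size=16 and learning_rate=2e-5 this
# yields something like "models/eng.rst.gum_xlmr_seg_16_0.00002".
#
#   out = set_name_output_dir("models",
#                             {"model_name": "xlmr", "task": "seg",
#                              "trainer_config": {"batch_size": 16, "learning_rate": 2e-5}},
#                             "eng.rst.gum")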