import sys

import numpy as np
from transformers import Pipeline, AutoModelForTokenClassification

from eval import retrieve_predictions, align_tokens_labels_from_wordids
from reading import read_dataset
from utils import read_config
|
|
def write_sentences_to_format(sentences: list[str], filename: str | None) -> str:
    """
    Writes sentences one word per line, in the format:

        index<TAB>word<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>Seg=...

    Returns the formatted string; if `filename` is given, the string is also
    written to that file.
    """
    if not sentences:
        return ""
    if isinstance(sentences, str):
        sentences = [sentences]
        sys.stderr.write("Warning: only one sentence provided as a string instead of a list of sentences.\n")

    # Hard-coded document header.
    full = "# newdoc_id = GUM_academic_discrimination\n"
    for sentence in sentences:
        words = sentence.strip().split()
        for i, word in enumerate(words, start=1):
            # Placeholder labels: the first word of each sentence and any
            # capitalized word get "B-seg", every other word gets "O".
            seg_label = "B-seg" if i == 1 or word[0].isupper() else "O"
            line = f"{i}\t{word}\t_\t_\t_\t_\t_\t_\t_\tSeg={seg_label}\n"
            full += line
    if filename:
        with open(filename, "w", encoding="utf-8") as f:
            f.write(full)

    return full
|
|
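# Illustrative example (kept as a comment so it is not executed at import time):
# how a single made-up sentence is rendered by write_sentences_to_format.
#
#   out = write_sentences_to_format(["Hello world"], filename=None)
#   out.splitlines()[0]  # '# newdoc_id = GUM_academic_discrimination'
#   out.splitlines()[1]  # '1\tHello\t_\t_\t_\t_\t_\t_\t_\tSeg=B-seg'
#   out.splitlines()[2]  # '2\tworld\t_\t_\t_\t_\t_\t_\t_\tSeg=O'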
|
|
class DiscoursePipeline(Pipeline):
    def __init__(self, model, tokenizer, output_folder="./pipe_out", sat_model: str = "sat-3l", **kwargs):
        # `model` is a checkpoint path; the loaded model is handed to the parent Pipeline.
        auto_model = AutoModelForTokenClassification.from_pretrained(model)
        super().__init__(model=auto_model, tokenizer=tokenizer, **kwargs)
        self.config = {
            "model_checkpoint": model,
            "sent_spliter": "sat",
            "task": "seg",
            "type": "tok",
            "trace": False,
            "report_to": "none",
            "sat_model": sat_model,
            "tok_config": {
                "padding": "max_length",
                "truncation": True,
                "max_length": 512,
            },
        }
        # Note: this overwrites the model object set by the parent class with the
        # checkpoint path, which is what retrieve_predictions receives in _forward.
        self.model = model
        self.output_folder = output_folder

    def _sanitize_parameters(self, **kwargs):
        # No call-time parameters are forwarded to preprocess/_forward/postprocess.
        preprocess_params = {}
        forward_params = {}
        postprocess_params = {}
        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, text: str):
        # Keep the raw text so postprocess can rebuild EDUs from it.
        self.original_text = text
        formatted_text = write_sentences_to_format(text.split("\n"), filename=None)
        dataset, _ = read_dataset(
            formatted_text,
            output_path=self.output_folder,
            config=self.config,
            add_lang_token=True,
            add_frame_token=True,
        )
        return {"dataset": dataset}

    def _forward(self, inputs):
        dataset = inputs["dataset"]
        preds_from_model, label_ids, _ = retrieve_predictions(
            self.model, dataset, self.output_folder, self.tokenizer, self.config
        )
        return {"preds": preds_from_model, "labels": label_ids, "dataset": dataset}

    def postprocess(self, outputs):
        preds = np.argmax(outputs["preds"], axis=-1)
        predictions = align_tokens_labels_from_wordids(preds, outputs["dataset"], self.tokenizer)
        edus = text_to_edus(self.original_text, predictions)
        return edus
|
|
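# Sketch of how the pipeline is meant to be used (the checkpoint path below is a
# placeholder, not a released model):
#
#   from transformers import AutoTokenizer
#
#   checkpoint = "path/to/segmentation-checkpoint"  # hypothetical path
#   tokenizer = AutoTokenizer.from_pretrained(checkpoint)
#   pipe = DiscoursePipeline(checkpoint, tokenizer, output_folder="./pipe_out")
#   edus = pipe("First sentence of the document .\nSecond sentence .")
#   # `edus` is the list of EDU strings produced by postprocess().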
|
|
|
def get_plain_text_from_format(formatted_text: str) -> str:
    """
    Reads conllu- or tok-formatted text and returns its word content
    (the second column) as a single space-separated string.
    """
    formatted_text = formatted_text.split("\n")
    s = ""
    for line in formatted_text:
        if not line.startswith("#"):
            if len(line.split("\t")) > 1:
                s += line.split("\t")[1] + " "
    return s.strip()
|
|
def get_preds_from_format(formatted_text: str) -> str:
    """
    Reads conllu- or tok-formatted text and returns its label content
    (the last column) as a single space-separated string.
    """
    formatted_text = formatted_text.split("\n")
    s = ""
    for line in formatted_text:
        if not line.startswith("#"):
            if len(line.split("\t")) > 1:
                s += line.split("\t")[-1] + " "
    return s.strip()
|
|
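# Example: the two helpers above read the same tab-separated lines; one keeps the
# word column, the other the final (label) column.
#
#   sample = (
#       "# newdoc_id = demo\n"
#       "1\tHello\t_\t_\t_\t_\t_\t_\t_\tSeg=B-seg\n"
#       "2\tworld\t_\t_\t_\t_\t_\t_\t_\tSeg=O"
#   )
#   get_plain_text_from_format(sample)  # 'Hello world'
#   get_preds_from_format(sample)       # 'Seg=B-seg Seg=O'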
|
|
def text_to_edus(text: str, labels: list[str]) -> list[str]:
    """
    Splits raw text into EDUs based on a sequence of BIO labels.

    Args:
        text (str): The raw text (a sequence of space-separated words).
        labels (list[str]): The sequence of BIO labels (B, I, O),
            with the same length as the number of tokens in the text.

    Returns:
        list[str]: The list of EDUs (each EDU is a substring of the text).
    """
    words = text.strip().split()
    if len(words) != len(labels):
        raise ValueError(f"Length mismatch: {len(words)} words vs {len(labels)} labels")

    edus = []
    current_edu = []

    for word, label in zip(words, labels):
        if label == "Conn=O" or label == "Seg=O":
            current_edu.append(word)
        elif label == "Conn=B-conn" or label == "Seg=B-seg":
            # A "B" label starts a new EDU: flush the current one first.
            if current_edu:
                edus.append(" ".join(current_edu))
                current_edu = []
            current_edu.append(word)
        else:
            # Any other label is treated as a continuation so that no word is dropped.
            current_edu.append(word)

    if current_edu:
        edus.append(" ".join(current_edu))

    return edus
|
|
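# Example of the BIO-style splitting performed by text_to_edus:
#
#   text_to_edus(
#       "I left because it rained",
#       ["Seg=B-seg", "Seg=O", "Seg=B-seg", "Seg=O", "Seg=O"],
#   )
#   # -> ['I left', 'because it rained']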
|