Spaces:
Sleeping
Sleeping
Update legaldoc_summarizer/app.py
Browse files- legaldoc_summarizer/app.py +21 -4
legaldoc_summarizer/app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from shared.hf_helpers import build_pipeline
|
| 3 |
import yaml
|
|
@@ -10,13 +11,29 @@ def main():
|
|
| 10 |
with open(CONFIG_PATH) as f:
|
| 11 |
cfg = yaml.safe_load(f)
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
@st.cache_resource
|
| 16 |
-
def get_pipeline(model_name):
|
| 17 |
-
return build_pipeline(model_name)
|
| 18 |
|
| 19 |
-
pipe = get_pipeline(model_name)
|
| 20 |
|
| 21 |
st.write("Paste a contract clause or judgment text below:")
|
| 22 |
text = st.text_area("Clause or Legal Text", height=250)
|
|
|
|
| 1 |
+
import os
|
| 2 |
import streamlit as st
|
| 3 |
from shared.hf_helpers import build_pipeline
|
| 4 |
import yaml
|
|
|
|
| 11 |
with open(CONFIG_PATH) as f:
|
| 12 |
cfg = yaml.safe_load(f)
|
| 13 |
|
| 14 |
+
base_model = cfg["base_model"]
|
| 15 |
+
finetuned_model = cfg.get("finetuned_model") or os.getenv("LEGALDOC_MODEL_ID")
|
| 16 |
+
local_model_dir = Path(cfg.get("finetuned_local_dir", "models/legaldoc_summarizer"))
|
| 17 |
+
|
| 18 |
+
model_options = [base_model]
|
| 19 |
+
if finetuned_model:
|
| 20 |
+
model_options.append(finetuned_model)
|
| 21 |
+
elif local_model_dir.exists():
|
| 22 |
+
model_options.append(str(local_model_dir))
|
| 23 |
+
else:
|
| 24 |
+
st.info(
|
| 25 |
+
"Using the base model until a fine-tuned checkpoint is available. "
|
| 26 |
+
"Train a model to populate `models/legaldoc_summarizer` or set `LEGALDOC_MODEL_ID` / `finetuned_model`."
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
model_name = st.selectbox("Model:", model_options)
|
| 30 |
+
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
|
| 31 |
|
| 32 |
@st.cache_resource
|
| 33 |
+
def get_pipeline(model_name, token):
|
| 34 |
+
return build_pipeline(model_name, token=token)
|
| 35 |
|
| 36 |
+
pipe = get_pipeline(model_name, hf_token)
|
| 37 |
|
| 38 |
st.write("Paste a contract clause or judgment text below:")
|
| 39 |
text = st.text_area("Clause or Legal Text", height=250)
|