Skip to content

Commit f27c3e6

Browse files
fix: Gliner (#717)
Co-authored-by: Max Tkachenko <[email protected]>
1 parent 317474a commit f27c3e6

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

label_studio_ml/examples/gliner/docker-compose.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ services:
2222
# specify the model directory (likely you don't need to change this)
2323
- MODEL_DIR=/data/models
2424

25+
# Path to your saved finetuned model
26+
- FINETUNED_MODEL_PATH=finetuned_model
27+
2528
# Specify the Label Studio URL and API key to access
2629
# uploaded, local storage and cloud storage files.
2730
# Do not use 'localhost' as it does not work within Docker containers.

label_studio_ml/examples/gliner/model.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
from math import floor
44
from typing import List, Dict, Optional
5+
import pathlib
56

67
import label_studio_sdk
78
from gliner import GLiNER
@@ -15,8 +16,6 @@
1516
logger = logging.getLogger(__name__)
1617

1718
GLINER_MODEL_NAME = os.getenv("GLINER_MODEL_NAME", "urchade/gliner_medium-v2.1")
18-
logger.info(f"Loading GLINER model {GLINER_MODEL_NAME}")
19-
MODEL = GLiNER.from_pretrained(GLINER_MODEL_NAME)
2019

2120

2221
class GLiNERModel(LabelStudioMLBase):
@@ -29,10 +28,23 @@ def setup(self):
2928
"""
3029
self.LABEL_STUDIO_HOST = os.getenv('LABEL_STUDIO_URL', 'http://localhost:8080')
3130
self.LABEL_STUDIO_API_KEY = os.getenv('LABEL_STUDIO_API_KEY')
32-
33-
self.set("model_version", f'{self.__class__.__name__}-v0.0.1')
31+
self.MODEL_DIR = os.getenv("MODEL_DIR", "/data/models")
32+
self.finetuned_model_path = os.getenv("FINETUNED_MODEL_PATH", f"models/checkpoint-10")
3433
self.threshold = float(os.getenv('THRESHOLD', 0.5))
35-
self.model = MODEL
34+
self.model = None
35+
36+
def lazy_init(self):
    """Load the GLiNER model on first use (lazy initialization).

    Tries to load a fine-tuned checkpoint from
    ``MODEL_DIR / finetuned_model_path`` (local files only) and tags the
    run as model version v0.0.2; when that fails, falls back to the
    pretrained hub model ``GLINER_MODEL_NAME`` and tags v0.0.1.
    Does nothing when ``self.model`` is already loaded.
    """
    # `is None` rather than truthiness: a loaded model object that happens
    # to be falsy must not trigger a reload.
    if self.model is not None:
        return

    checkpoint = pathlib.Path(self.MODEL_DIR, self.finetuned_model_path)
    try:
        logger.info(f"Loading Pretrained Model from {self.finetuned_model_path}")
        self.model = GLiNER.from_pretrained(str(checkpoint), local_files_only=True)
        self.set("model_version", f'{self.__class__.__name__}-v0.0.2')
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; log the actual load failure for debugging
        # before falling back to the default hub model.
        logger.exception(f"No Pretrained Model Found. Loading GLINER model {GLINER_MODEL_NAME}")
        self.model = GLiNER.from_pretrained(GLINER_MODEL_NAME)
        self.set("model_version", f'{self.__class__.__name__}-v0.0.1')
3648

3749
def convert_to_ls_annotation(self, prediction, from_name, to_name):
3850
"""
@@ -107,6 +119,8 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -
107119
Parsed JSON Label config: {self.parsed_label_config}
108120
Extra params: {self.extra_params}''')
109121

122+
# TODO: this may result in single-time timeout for large models - consider adjusting the timeout on Label Studio side
123+
self.lazy_init()
110124
# make predictions with currently set model
111125
from_name, to_name, value = self.label_interface.get_first_tag_occurence('Labels', 'Text')
112126

@@ -149,6 +163,8 @@ def train(self, model, training_args, train_data, eval_data=None):
149163
:param train_data: the training data, as a list of dictionaries
150164
:param eval_data: the eval data
151165
"""
166+
# TODO: this may result in single-time timeout for large models - consider adjusting the timeout on Label Studio side
167+
self.lazy_init()
152168
logger.info("Training Model")
153169
if training_args.use_cpu == True:
154170
model = model.to('cpu')
@@ -168,6 +184,12 @@ def train(self, model, training_args, train_data, eval_data=None):
168184

169185
trainer.train()
170186

187+
#Save model
188+
ckpt = str(pathlib.Path(self.MODEL_DIR, self.finetuned_model_path))
189+
logger.info(f"Model Trained, saving to {ckpt} ")
190+
trainer.save_model(ckpt)
191+
192+
171193
def fit(self, event, data, **kwargs):
172194
"""
173195
This method is called each time an annotation is created or updated
@@ -177,6 +199,7 @@ def fit(self, event, data, **kwargs):
177199
:param event: event type can be ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED')
178200
:param data: the payload received from the event (check [Webhook event reference](https://labelstud.io/guide/webhook_reference.html))
179201
"""
202+
self.lazy_init()
180203
# we only train the model if the "start training" button is pressed from settings.
181204
if event == "START_TRAINING":
182205
logger.info("Fitting model")
@@ -211,7 +234,8 @@ def fit(self, event, data, **kwargs):
211234
num_epochs = max(1, floor(num_steps / num_batches))
212235

213236
training_args = TrainingArguments(
214-
output_dir="models",
237+
output_dir="models/training_output",
238+
215239
learning_rate=5e-6,
216240
weight_decay=0.01,
217241
others_lr=1e-5,
@@ -233,9 +257,5 @@ def fit(self, event, data, **kwargs):
233257

234258
self.train(self.model, training_args, training_data, eval_data)
235259

236-
logger.info("Saving new fine-tuned model as the default model")
237-
self.model = GLiNER.from_pretrained(f"models/checkpoint-10", local_files_only=True)
238-
model_version = int(self.model_version[-1]) + 1
239-
self.set("model_version", f'{self.__class__.__name__}-v{model_version}')
240260
else:
241261
logger.info("Model training not triggered")

0 commit comments

Comments
 (0)