22import os
33from math import floor
44from typing import List , Dict , Optional
5+ import pathlib
56
67import label_studio_sdk
78from gliner import GLiNER
# Module-level logger for this ML backend.
logger = logging.getLogger(__name__)

# Hugging Face model id used as the fallback when no fine-tuned
# checkpoint is available; overridable via the GLINER_MODEL_NAME env var.
GLINER_MODEL_NAME = os.environ.get("GLINER_MODEL_NAME", "urchade/gliner_medium-v2.1")
18- logger .info (f"Loading GLINER model { GLINER_MODEL_NAME } " )
19- MODEL = GLiNER .from_pretrained (GLINER_MODEL_NAME )
2019
2120
2221class GLiNERModel (LabelStudioMLBase ):
@@ -29,10 +28,23 @@ def setup(self):
2928 """
3029 self .LABEL_STUDIO_HOST = os .getenv ('LABEL_STUDIO_URL' , 'http://localhost:8080' )
3130 self .LABEL_STUDIO_API_KEY = os .getenv ('LABEL_STUDIO_API_KEY' )
32-
33- self .set ( "model_version " , f' { self . __class__ . __name__ } -v0.0.1' )
31+ self . MODEL_DIR = os . getenv ( "MODEL_DIR" , "/data/models" )
32+ self .finetuned_model_path = os . getenv ( "FINETUNED_MODEL_PATH " , f"models/checkpoint-10" )
3433 self .threshold = float (os .getenv ('THRESHOLD' , 0.5 ))
35- self .model = MODEL
34+ self .model = None
35+
36+ def lazy_init (self ):
37+ if not self .model :
38+ try :
39+ logger .info (f"Loading Pretrained Model from { self .finetuned_model_path } " )
40+ self .model = GLiNER .from_pretrained (str (pathlib .Path (self .MODEL_DIR , self .finetuned_model_path )), local_files_only = True )
41+ self .set ("model_version" , f'{ self .__class__ .__name__ } -v0.0.2' )
42+
43+ except :
44+ # If no finetuned model, use default
45+ logger .info (f"No Pretrained Model Found. Loading GLINER model { GLINER_MODEL_NAME } " )
46+ self .model = GLiNER .from_pretrained (GLINER_MODEL_NAME )
47+ self .set ("model_version" , f'{ self .__class__ .__name__ } -v0.0.1' )
3648
3749 def convert_to_ls_annotation (self , prediction , from_name , to_name ):
3850 """
@@ -107,6 +119,8 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -
107119 Parsed JSON Label config: { self .parsed_label_config }
108120 Extra params: { self .extra_params } ''' )
109121
122+ # TODO: this may result in single-time timeout for large models - consider adjusting the timeout on Label Studio side
123+ self .lazy_init ()
110124 # make predictions with currently set model
111125 from_name , to_name , value = self .label_interface .get_first_tag_occurence ('Labels' , 'Text' )
112126
@@ -149,6 +163,8 @@ def train(self, model, training_args, train_data, eval_data=None):
149163 :param train_data: the training data, as a list of dictionaries
150164 :param eval_data: the eval data
151165 """
166+ # TODO: this may result in single-time timeout for large models - consider adjusting the timeout on Label Studio side
167+ self .lazy_init ()
152168 logger .info ("Training Model" )
153169 if training_args .use_cpu == True :
154170 model = model .to ('cpu' )
@@ -168,6 +184,12 @@ def train(self, model, training_args, train_data, eval_data=None):
168184
169185 trainer .train ()
170186
187+ #Save model
188+ ckpt = str (pathlib .Path (self .MODEL_DIR , self .finetuned_model_path ))
189+ logger .info (f"Model Trained, saving to { ckpt } " )
190+ trainer .save_model (ckpt )
191+
192+
171193 def fit (self , event , data , ** kwargs ):
172194 """
173195 This method is called each time an annotation is created or updated
@@ -177,6 +199,7 @@ def fit(self, event, data, **kwargs):
177199 :param event: event type can be ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED')
178200 :param data: the payload received from the event (check [Webhook event reference](https://labelstud.io/guide/webhook_reference.html))
179201 """
202+ self .lazy_init ()
180203 # we only train the model if the "start training" button is pressed from settings.
181204 if event == "START_TRAINING" :
182205 logger .info ("Fitting model" )
@@ -211,7 +234,8 @@ def fit(self, event, data, **kwargs):
211234 num_epochs = max (1 , floor (num_steps / num_batches ))
212235
213236 training_args = TrainingArguments (
214- output_dir = "models" ,
237+ output_dir = "models/training_output" ,
238+
215239 learning_rate = 5e-6 ,
216240 weight_decay = 0.01 ,
217241 others_lr = 1e-5 ,
@@ -233,9 +257,5 @@ def fit(self, event, data, **kwargs):
233257
234258 self .train (self .model , training_args , training_data , eval_data )
235259
236- logger .info ("Saving new fine-tuned model as the default model" )
237- self .model = GLiNER .from_pretrained (f"models/checkpoint-10" , local_files_only = True )
238- model_version = int (self .model_version [- 1 ]) + 1
239- self .set ("model_version" , f'{ self .__class__ .__name__ } -v{ model_version } ' )
240260 else :
241261 logger .info ("Model training not triggered" )
0 commit comments