@@ -227,7 +227,7 @@ def build_unit_name(dimensions, lang="en_US"):
227227
228228
229229###############################################################################
230- def get_unit_from_dimensions (dimensions , text , lang = "en_US" ):
230+ def get_unit_from_dimensions (dimensions , text , lang = "en_US" , classifier_path = None ):
231231 """
232232 Reconcile a unit based on its dimensionality.
233233 """
@@ -241,7 +241,7 @@ def get_unit_from_dimensions(dimensions, text, lang="en_US"):
241241 unit = cls .Unit (
242242 name = build_unit_name (dimensions , lang ),
243243 dimensions = dimensions ,
244- entity = get_entity_from_dimensions (dimensions , text , lang ),
244+ entity = get_entity_from_dimensions (dimensions , text , lang , classifier_path ),
245245 )
246246
247247 # Carry on original composition
@@ -268,7 +268,7 @@ def infer_name(unit):
268268
269269
270270###############################################################################
271- def get_entity_from_dimensions (dimensions , text , lang = "en_US" ):
271+ def get_entity_from_dimensions (dimensions , text , lang = "en_US" , classifier_path = None ):
272272 """
273273 Infer the underlying entity of a unit (e.g. "volume" for "m^3") based on
274274 its dimensionality.
@@ -282,7 +282,7 @@ def get_entity_from_dimensions(dimensions, text, lang="en_US"):
282282 final_derived = sorted (new_derived , key = lambda x : x ["base" ])
283283 key = load .get_key_from_dimensions (final_derived )
284284
285- ent = dis .disambiguate_entity (key , text , lang )
285+ ent = dis .disambiguate_entity (key , text , lang , classifier_path = classifier_path )
286286 if ent is None :
287287 _LOGGER .debug ("\t Could not find entity for: %s" , key )
288288 ent = cls .Entity (name = "unknown" , dimensions = new_derived )
@@ -299,7 +299,7 @@ def parse_unit(item, unit, slash, lang="en_US"):
299299
300300
301301###############################################################################
302- def get_unit (item , text , lang = "en_US" ):
302+ def get_unit (item , text , lang = "en_US" , classifier_path = None ):
303303 """
304304 Extract unit from regex hit.
305305 """
@@ -362,10 +362,10 @@ def get_unit(item, text, lang="en_US"):
362362 # Determine which unit follows
363363 if unit :
364364 unit_surface , power = parse_unit (item , unit , slash , lang )
365- base = dis .disambiguate_unit (unit_surface , text , lang )
365+ base = dis .disambiguate_unit (unit_surface , text , lang , classifier_path )
366366 derived += [{"base" : base , "power" : power , "surface" : unit_surface }]
367367
368- unit = get_unit_from_dimensions (derived , text , lang )
368+ unit = get_unit_from_dimensions (derived , text , lang , classifier_path )
369369
370370 _LOGGER .debug ("\t Unit: %s" , unit )
371371 _LOGGER .debug ("\t Entity: %s" , unit .entity )
@@ -424,14 +424,23 @@ def is_quote_artifact(orig_text, span):
424424
425425###############################################################################
426426def build_quantity (
427- orig_text , text , item , values , unit , surface , span , uncert , lang = "en_US"
427+ orig_text ,
428+ text ,
429+ item ,
430+ values ,
431+ unit ,
432+ surface ,
433+ span ,
434+ uncert ,
435+ lang = "en_US" ,
436+ classifier_path = None ,
428437):
429438 """
430439 Build a Quantity object out of extracted information.
431440 Takes care of caveats and common errors
432441 """
433442 return _get_parser (lang ).build_quantity (
434- orig_text , text , item , values , unit , surface , span , uncert
443+ orig_text , text , item , values , unit , surface , span , uncert , classifier_path
435444 )
436445
437446
@@ -479,8 +488,8 @@ def handle_consecutive_quantities(quantities, context):
479488 if range_span :
480489 if q1 .unit .name == q2 .unit .name or q1 .unit .name == "dimensionless" :
481490 if (
482- q1 .uncertainty == None
483- and q2 .uncertainty == None
491+ q1 .uncertainty is None
492+ and q2 .uncertainty is None
484493 and q1 .value != q2 .value
485494 ):
486495 a , b = (q1 , q2 ) if q2 .value > q1 .value else (q2 , q1 )
@@ -505,9 +514,28 @@ def handle_consecutive_quantities(quantities, context):
505514
506515
507516###############################################################################
508- def parse (text , lang = "en_US" , verbose = False ) -> List [cls .Quantity ]:
517+ def parse (
518+ text , lang = "en_US" , verbose = False , classifier_path = None
519+ ) -> List [cls .Quantity ]:
509520 """
510521 Extract all quantities from unstructured text.
522+
523+ Parameters
524+ ----------
525+ text : str
526+ Text to parse.
527+ lang : str
528+ Language of the text. Default is "en_US".
529+ verbose : bool
530+ If True, print debug information. Default is False.
531+ classifier_path : str
532+ Path to the classifier model. Default is None, which uses the default
533+ model for the given language.
534+
535+ Returns
536+ -------
537+ quantities : List[Quantity]
538+ List of quantities found in the text.
511539 """
512540
513541 log_format = "%(asctime)s --- %(message)s"
@@ -533,10 +561,19 @@ def parse(text, lang="en_US", verbose=False) -> List[cls.Quantity]:
533561 try :
534562 uncert , values = get_values (item , lang )
535563
536- unit , unit_shortening = get_unit (item , text )
564+ unit , unit_shortening = get_unit (item , text , lang , classifier_path )
537565 surface , span = get_surface (shifts , orig_text , item , text , unit_shortening )
538566 objs = build_quantity (
539- orig_text , text , item , values , unit , surface , span , uncert , lang
567+ orig_text ,
568+ text ,
569+ item ,
570+ values ,
571+ unit ,
572+ surface ,
573+ span ,
574+ uncert ,
575+ lang ,
576+ classifier_path ,
540577 )
541578 if objs is not None :
542579 quantities += objs
0 commit comments