Skip to content

Commit b7c5586

Browse files
authored
Merge pull request #223 from adhardydm/use-custom-model-file
Use custom model file
2 parents 3af1e1e + 240a723 commit b7c5586

File tree

9 files changed

+283
-58
lines changed

9 files changed

+283
-58
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ script:
5050
- coverage run -a --source=quantulum3 setup.py test -s quantulum3.tests.test_classifier.ClassifierBuild
5151
- coverage run -a --source=quantulum3 setup.py test -s quantulum3.tests.test_classifier.ClassifierTest
5252
- coverage run -a --source=quantulum3 setup.py test -s quantulum3.tests.test_scripts.TrainScriptTest
53+
- coverage run -a --source=quantulum3 setup.py test -s quantulum3.tests.test_load.TestCached
5354

5455
after_success:
5556
- coverage report

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,14 @@ converted to a `.magnitude` file on-the-run. Check out
184184
and [Magnitude](https://github.com/plasticityai/magnitude) for more information.
185185

186186

187+
To use your custom model, pass the path to the trained model file to the
188+
parser:
189+
190+
```python
191+
quantities = parser.parse("I want 2 liters of wine", classifier_path="path/to/model")
192+
```
193+
194+
187195
Manipulation
188196
------------
189197

quantulum3/_lang/en_US/parser.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,9 @@ def parse_unit(_, unit, slash):
234234

235235

236236
###############################################################################
237-
def build_quantity(orig_text, text, item, values, unit, surface, span, uncert):
237+
def build_quantity(
238+
orig_text, text, item, values, unit, surface, span, uncert, classifier_path=None
239+
):
238240
"""
239241
Build a Quantity object out of extracted information.
240242
"""
@@ -425,7 +427,7 @@ def build_quantity(orig_text, text, item, values, unit, surface, span, uncert):
425427
if dimension_change:
426428
if unit.original_dimensions:
427429
unit = parser.get_unit_from_dimensions(
428-
unit.original_dimensions, orig_text, lang
430+
unit.original_dimensions, orig_text, lang, classifier_path
429431
)
430432
else:
431433
unit = load.units(lang).names["dimensionless"]

quantulum3/classifier.py

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
import multiprocessing
99
import os
10+
import warnings
1011

1112
import pkg_resources
1213

@@ -24,6 +25,11 @@
2425
SGDClassifier, TfidfVectorizer = None, None
2526
USE_CLF = False
2627

28+
warnings.warn(
29+
"Classifier dependencies not installed. Run pip install quantulum3[classifier] "
30+
"to install them. The classifer helps to dissambiguate units."
31+
)
32+
2733
try:
2834
import wikipedia
2935
except ImportError:
@@ -203,9 +209,20 @@ def train_classifier(
203209

204210
###############################################################################
205211
class Classifier(object):
206-
def __init__(self, obj=None, lang="en_US"):
212+
def __init__(self, classifier_object=None, lang="en_US", classifier_path=None):
207213
"""
208214
Load the intent classifier
215+
216+
Parameters
217+
----------
218+
classifier_object : dict
219+
Classifier object as returned by train_classifier
220+
lang : str
221+
Language to use (ignored if a classifier object or path is given)
222+
classifier_path : str
223+
Path to a joblib file containing the classifier. If None, the
224+
classifier will be loaded from the default location for the given
225+
language.
209226
"""
210227
self.tfidf_model = None
211228
self.classifier = None
@@ -214,47 +231,53 @@ def __init__(self, obj=None, lang="en_US"):
214231
if not USE_CLF:
215232
return
216233

217-
if not obj:
218-
path = language.topdir(lang).joinpath("clf.joblib")
219-
with path.open("rb") as file:
220-
obj = joblib.load(file)
234+
if not classifier_object:
235+
if classifier_path is None:
236+
classifier_path = language.topdir(lang).joinpath("clf.joblib")
237+
with classifier_path.open("rb") as file:
238+
classifier_object = joblib.load(file)
221239

222240
cur_scipy_version = pkg_resources.get_distribution("scikit-learn").version
223-
if cur_scipy_version != obj.get("scikit-learn_version"): # pragma: no cover
241+
if cur_scipy_version != classifier_object.get(
242+
"scikit-learn_version"
243+
): # pragma: no cover
224244
_LOGGER.warning(
225245
"The classifier was built using a different scikit-learn "
226246
"version (={}, !={}). The disambiguation tool could behave "
227247
"unexpectedly. Consider running classifier.train_classfier()".format(
228-
obj.get("scikit-learn_version"), cur_scipy_version
248+
classifier_object.get("scikit-learn_version"), cur_scipy_version
229249
)
230250
)
231251

232-
self.tfidf_model = obj["tfidf_model"]
233-
self.classifier = obj["clf"]
234-
self.target_names = obj["target_names"]
252+
self.tfidf_model = classifier_object["tfidf_model"]
253+
self.classifier = classifier_object["clf"]
254+
self.target_names = classifier_object["target_names"]
235255

236256

237257
@cached
238-
def classifier(lang="en_US"):
258+
def classifier(lang: str = "en_US", classifier_path: str = None) -> Classifier:
239259
"""
240260
Cached classifier object
241-
:param lang:
242-
:return:
261+
:param lang: language
262+
:param classifier_path: path to a joblib file containing the classifier
263+
:return: Classifier object
243264
"""
244-
return Classifier(lang=lang)
265+
return Classifier(lang=lang, classifier_path=classifier_path)
245266

246267

247268
###############################################################################
248-
def disambiguate_entity(key, text, lang="en_US"):
269+
def disambiguate_entity(key, text, lang="en_US", classifier_path=None):
249270
"""
250271
Resolve ambiguity between entities with same dimensionality.
251272
"""
252273

253274
new_ent = next(iter(load.entities(lang).derived[key]))
254275
if len(load.entities(lang).derived[key]) > 1:
255-
transformed = classifier(lang).tfidf_model.transform([clean_text(text, lang)])
256-
scores = classifier(lang).classifier.predict_proba(transformed).tolist()[0]
257-
scores = zip(scores, classifier(lang).target_names)
276+
classifier_: Classifier = classifier(lang, classifier_path)
277+
278+
transformed = classifier_.tfidf_model.transform([clean_text(text, lang)])
279+
scores = classifier_.classifier.predict_proba(transformed).tolist()[0]
280+
scores = zip(scores, classifier_.target_names)
258281

259282
# Filter for possible names
260283
names = [i.name for i in load.entities(lang).derived[key]]
@@ -271,7 +294,7 @@ def disambiguate_entity(key, text, lang="en_US"):
271294

272295

273296
###############################################################################
274-
def disambiguate_unit(unit, text, lang="en_US"):
297+
def disambiguate_unit(unit, text, lang="en_US", classifier_path=None):
275298
"""
276299
Resolve ambiguity between units with same names, symbols or abbreviations.
277300
"""
@@ -286,9 +309,11 @@ def disambiguate_unit(unit, text, lang="en_US"):
286309
return load.units(lang).names.get("unk")
287310

288311
if len(new_unit) > 1:
289-
transformed = classifier(lang).tfidf_model.transform([clean_text(text, lang)])
290-
scores = classifier(lang).classifier.predict_proba(transformed).tolist()[0]
291-
scores = zip(scores, classifier(lang).target_names)
312+
classifier_: Classifier = classifier(lang, classifier_path)
313+
314+
transformed = classifier_.tfidf_model.transform([clean_text(text, lang)])
315+
scores = classifier_.classifier.predict_proba(transformed).tolist()[0]
316+
scores = zip(scores, classifier_.target_names)
292317

293318
# Filter for possible names
294319
names = [i.name for i in new_unit]

quantulum3/disambiguate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99

1010

1111
###############################################################################
12-
def disambiguate_unit(unit_surface, text, lang="en_US"):
12+
def disambiguate_unit(unit_surface, text, lang="en_US", classifier_path=None):
1313
"""
1414
Resolve ambiguity between units with same names, symbols or abbreviations.
1515
:returns (str) unit name of the resolved unit
1616
"""
1717
if clf.USE_CLF:
18-
base = clf.disambiguate_unit(unit_surface, text, lang).name
18+
base = clf.disambiguate_unit(unit_surface, text, lang, classifier_path).name
1919
else:
2020
base = (
2121
load.units(lang).symbols[unit_surface]
@@ -38,13 +38,13 @@ def disambiguate_unit(unit_surface, text, lang="en_US"):
3838

3939

4040
###############################################################################
41-
def disambiguate_entity(key, text, lang="en_US"):
41+
def disambiguate_entity(key, text, lang="en_US", classifier_path=None):
4242
"""
4343
Resolve ambiguity between entities with same dimensionality.
4444
"""
4545
try:
4646
if clf.USE_CLF:
47-
ent = clf.disambiguate_entity(key, text, lang)
47+
ent = clf.disambiguate_entity(key, text, lang, classifier_path)
4848
else:
4949
derived = load.entities().derived[key]
5050
if len(derived) > 1:

quantulum3/load.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@
1515
TOPDIR = Path(__file__).parent or Path(".")
1616

1717
###############################################################################
18-
_CACHE_DICT = {}
18+
_CACHE_DICT = defaultdict(dict)
19+
20+
21+
def clear_cache():
22+
"""
23+
Useful for testing.
24+
"""
25+
_CACHE_DICT.clear()
1926

2027

2128
def cached(funct):
@@ -28,12 +35,15 @@ def cached(funct):
2835
"""
2936
assert callable(funct)
3037

31-
def cached_function(lang="en_US"):
38+
def cached_function(*args, **kwargs):
39+
# create a hash of args and kwargs
40+
args_hash = hash((args, frozenset(kwargs.items())))
41+
3242
try:
33-
return _CACHE_DICT[id(funct)][lang]
43+
return _CACHE_DICT[id(funct)][args_hash]
3444
except KeyError:
35-
result = funct(lang)
36-
_CACHE_DICT[id(funct)] = {lang: result}
45+
result = funct(*args, **kwargs)
46+
_CACHE_DICT[id(funct)].update({args_hash: result})
3747
return result
3848

3949
return cached_function

quantulum3/parser.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def build_unit_name(dimensions, lang="en_US"):
227227

228228

229229
###############################################################################
230-
def get_unit_from_dimensions(dimensions, text, lang="en_US"):
230+
def get_unit_from_dimensions(dimensions, text, lang="en_US", classifier_path=None):
231231
"""
232232
Reconcile a unit based on its dimensionality.
233233
"""
@@ -241,7 +241,7 @@ def get_unit_from_dimensions(dimensions, text, lang="en_US"):
241241
unit = cls.Unit(
242242
name=build_unit_name(dimensions, lang),
243243
dimensions=dimensions,
244-
entity=get_entity_from_dimensions(dimensions, text, lang),
244+
entity=get_entity_from_dimensions(dimensions, text, lang, classifier_path),
245245
)
246246

247247
# Carry on original composition
@@ -268,7 +268,7 @@ def infer_name(unit):
268268

269269

270270
###############################################################################
271-
def get_entity_from_dimensions(dimensions, text, lang="en_US"):
271+
def get_entity_from_dimensions(dimensions, text, lang="en_US", classifier_path=None):
272272
"""
273273
Infer the underlying entity of a unit (e.g. "volume" for "m^3") based on
274274
its dimensionality.
@@ -282,7 +282,7 @@ def get_entity_from_dimensions(dimensions, text, lang="en_US"):
282282
final_derived = sorted(new_derived, key=lambda x: x["base"])
283283
key = load.get_key_from_dimensions(final_derived)
284284

285-
ent = dis.disambiguate_entity(key, text, lang)
285+
ent = dis.disambiguate_entity(key, text, lang, classifier_path=classifier_path)
286286
if ent is None:
287287
_LOGGER.debug("\tCould not find entity for: %s", key)
288288
ent = cls.Entity(name="unknown", dimensions=new_derived)
@@ -299,7 +299,7 @@ def parse_unit(item, unit, slash, lang="en_US"):
299299

300300

301301
###############################################################################
302-
def get_unit(item, text, lang="en_US"):
302+
def get_unit(item, text, lang="en_US", classifier_path=None):
303303
"""
304304
Extract unit from regex hit.
305305
"""
@@ -362,10 +362,10 @@ def get_unit(item, text, lang="en_US"):
362362
# Determine which unit follows
363363
if unit:
364364
unit_surface, power = parse_unit(item, unit, slash, lang)
365-
base = dis.disambiguate_unit(unit_surface, text, lang)
365+
base = dis.disambiguate_unit(unit_surface, text, lang, classifier_path)
366366
derived += [{"base": base, "power": power, "surface": unit_surface}]
367367

368-
unit = get_unit_from_dimensions(derived, text, lang)
368+
unit = get_unit_from_dimensions(derived, text, lang, classifier_path)
369369

370370
_LOGGER.debug("\tUnit: %s", unit)
371371
_LOGGER.debug("\tEntity: %s", unit.entity)
@@ -424,14 +424,23 @@ def is_quote_artifact(orig_text, span):
424424

425425
###############################################################################
426426
def build_quantity(
427-
orig_text, text, item, values, unit, surface, span, uncert, lang="en_US"
427+
orig_text,
428+
text,
429+
item,
430+
values,
431+
unit,
432+
surface,
433+
span,
434+
uncert,
435+
lang="en_US",
436+
classifier_path=None,
428437
):
429438
"""
430439
Build a Quantity object out of extracted information.
431440
Takes care of caveats and common errors
432441
"""
433442
return _get_parser(lang).build_quantity(
434-
orig_text, text, item, values, unit, surface, span, uncert
443+
orig_text, text, item, values, unit, surface, span, uncert, classifier_path
435444
)
436445

437446

@@ -479,8 +488,8 @@ def handle_consecutive_quantities(quantities, context):
479488
if range_span:
480489
if q1.unit.name == q2.unit.name or q1.unit.name == "dimensionless":
481490
if (
482-
q1.uncertainty == None
483-
and q2.uncertainty == None
491+
q1.uncertainty is None
492+
and q2.uncertainty is None
484493
and q1.value != q2.value
485494
):
486495
a, b = (q1, q2) if q2.value > q1.value else (q2, q1)
@@ -505,9 +514,28 @@ def handle_consecutive_quantities(quantities, context):
505514

506515

507516
###############################################################################
508-
def parse(text, lang="en_US", verbose=False) -> List[cls.Quantity]:
517+
def parse(
518+
text, lang="en_US", verbose=False, classifier_path=None
519+
) -> List[cls.Quantity]:
509520
"""
510521
Extract all quantities from unstructured text.
522+
523+
Parameters
524+
----------
525+
text : str
526+
Text to parse.
527+
lang : str
528+
Language of the text. Default is "en_US".
529+
verbose : bool
530+
If True, print debug information. Default is False.
531+
classifier_path : str
532+
Path to the classifier model. Default is None, which uses the default
533+
model for the given language.
534+
535+
Returns
536+
-------
537+
quantities : List[Quantity]
538+
List of quantities found in the text.
511539
"""
512540

513541
log_format = "%(asctime)s --- %(message)s"
@@ -533,10 +561,19 @@ def parse(text, lang="en_US", verbose=False) -> List[cls.Quantity]:
533561
try:
534562
uncert, values = get_values(item, lang)
535563

536-
unit, unit_shortening = get_unit(item, text)
564+
unit, unit_shortening = get_unit(item, text, lang, classifier_path)
537565
surface, span = get_surface(shifts, orig_text, item, text, unit_shortening)
538566
objs = build_quantity(
539-
orig_text, text, item, values, unit, surface, span, uncert, lang
567+
orig_text,
568+
text,
569+
item,
570+
values,
571+
unit,
572+
surface,
573+
span,
574+
uncert,
575+
lang,
576+
classifier_path,
540577
)
541578
if objs is not None:
542579
quantities += objs

0 commit comments

Comments
 (0)