@@ -96,48 +96,47 @@ def load(id):
9696DEFAULT_TIMEOUT = 3.05 # 100 requests per 5 minutes
9797
9898
def request_query(query, offset, limit, cache, session, timeout=DEFAULT_TIMEOUT):
    """Fetch one page of SemanticScholar search results, memoized by URL.

    :param query: The search query string.
    :param offset: Offset of the first result in the requested page.
    :param limit: Maximum number of results in the page.
    :param cache: Mapping (e.g., an open shelve) used to memoize replies,
                  keyed by the full request URL.
    :param session: The `requests.Session` used for the HTTP request.
    :param timeout: Seconds to sleep before an uncached request, to stay
                    under the API rate limit (see `DEFAULT_TIMEOUT`).
    :returns: The decoded JSON response.
    :raises Exception: If the reply contains no ``data`` field.
    """
    # The full URL (query string included) doubles as the cache key, so
    # identical requests are answered without touching the network.
    params = urlencode(dict(query=query, offset=offset, limit=limit))
    url = f"{S2_QUERY_URL}?{params}"

    if url in cache:
        return cache[url]

    # Throttle uncached requests, mirroring `request_paper`. Previously the
    # `timeout` parameter was accepted but silently ignored, so repeated
    # pagination requests were not rate-limited at all.
    sleep(timeout)
    reply = session.get(url)
    response = reply.json()

    if "data" not in response:
        msg = response.get("error") or response.get("message") or "unknown"
        raise Exception(f"error while fetching {reply.url}: {msg}")

    cache[url] = response
    return response
115115
116116
def request_paper(key, cache, session, timeout=DEFAULT_TIMEOUT):
    """Retrieve SemanticScholar metadata for a single paper, memoized by URL.

    :param key: Paper identifier (S2 id, DOI, ``PMID:...``, ``arXiv:...``, ...).
    :param cache: Mapping used to memoize successful replies, keyed by URL.
    :param session: The `requests.Session` used for the HTTP request.
    :param timeout: Seconds to sleep before an uncached request (rate limiting).
    :returns: The decoded JSON record, or `None` when retrieval failed.
    """
    url = S2_PAPER_URL + quote_plus(key)

    # Cached replies are returned immediately and skip the rate-limit sleep.
    if url in cache:
        return cache[url]

    try:
        # Throttle uncached requests to stay under the API rate limit.
        sleep(timeout)
        data = session.get(url).json()
    except Exception as exc:
        logging.warning(f"failed to retrieve {key}: {exc}")
        return None

    if "paperId" in data:
        # Only successful lookups are cached, so failures may be retried later.
        cache[url] = data
        return data

    reason = data.get("error") or data.get("message") or "unknown error"
    logging.warning(f"failed to retrieve {key}: {reason}")
    return None
139137
140- def fetch_semanticscholar (key : set ) -> Optional [Document ]:
138+
139+ def fetch_semanticscholar (key : set , * , session = None ) -> Optional [Document ]:
141140 """Fetch SemanticScholar metadata for the given key. The key can be
142141 one of the following (see `API reference
143142 <https://www.semanticscholar.org/product/api>`_):
@@ -150,63 +149,76 @@ def fetch_semanticscholar(key: set) -> Optional[Document]:
150149 * PubMed ID (example format: `PMID:19872477`)
151150 * Corpus ID (example format: `CorpusID:37220927`)
152151
152+ :param session: The `requests.Session` to use for HTTP requests.
153153 :returns: The `Document` if it was found and `None` otherwise.
154154 """
155155
156156 if key is None :
157157 return None
158158
159+ if session is None :
160+ session = requests .Session ()
161+
159162 with shelve .open (CACHE_FILE ) as cache :
160163 if isinstance (key , DocumentIdentifier ):
161164 data = None
162165 if data is None and key .s2id :
163- data = request_paper (key .s2id , cache )
166+ data = request_paper (key .s2id , cache , session )
164167
165168 if data is None and key .doi :
166- data = request_paper (key .doi , cache )
169+ data = request_paper (key .doi , cache , session )
167170
168171 if data is None and key .pubmed :
169- data = request_paper (f"PMID:{ key .pubmed } " , cache )
172+ data = request_paper (f"PMID:{ key .pubmed } " , cache , session )
170173
171174 if data is None and key .arxivid :
172- data = request_paper (f"arXiv:{ key .arxivid } " , cache )
175+ data = request_paper (f"arXiv:{ key .arxivid } " , cache , session )
173176 else :
174- data = request_paper (key , cache )
177+ data = request_paper (key , cache , session )
175178
176179 if data is None :
177180 return None
178181
179182 return ScholarDocument (data )
180183
181184
182- def refine_semanticscholar (docs : DocumentSet ) -> Tuple [DocumentSet , DocumentSet ]:
def refine_semanticscholar(docs: DocumentSet, *, session=None) -> Tuple[DocumentSet, DocumentSet]:
    """Attempt to fetch SemanticScholar metadata for each document in the
    given set based on their DOIs. Returns a tuple containing two sets: the
    documents available on SemanticScholar and the remaining documents that
    were not found or do not have a DOI.

    :param session: The `requests.Session` to use for HTTP requests.
    :returns: The documents available on SemanticScholar and the remaining documents.
    """

    def callback(doc):
        # Documents already backed by SemanticScholar need no second lookup.
        if isinstance(doc, ScholarDocument):
            return doc

        # `session` is keyword-only in `fetch_semanticscholar`; passing it
        # positionally (as before) raised a TypeError on the first lookup.
        return fetch_semanticscholar(doc.id, session=session)

    return docs._refine_docs(callback)
196202
197203
198- def search_semanticscholar (query : str , * , limit : int = None , batch_size : int = 100 ) -> DocumentSet :
204+ def search_semanticscholar (
205+ query : str , * , limit : int = None , batch_size : int = 100 , session = None
206+ ) -> DocumentSet :
199207 """Submit the given query to SemanticScholar API and return the results
200208 as a `DocumentSet`.
201209
202210 :param query: The search query to submit.
203211 :param limit: The maximum number of results to return.
204212 :param batch_size: The number of results to retrieve per request. Must be at most 100.
213+ :param session: The `requests.Session` to use for HTTP requests.
205214 """
206215
207216 if not query :
208217 raise Exception ("no query specified in `search_semanticscholar`" )
209218
219+ if session is None :
220+ session = requests .Session ()
221+
210222 docs = []
211223
212224 with shelve .open (CACHE_FILE ) as cache :
@@ -215,7 +227,7 @@ def search_semanticscholar(query: str, *, limit: int = None, batch_size: int = 1
215227 while True :
216228 offset = len (paper_ids )
217229
218- response = request_query (query , offset , batch_size , cache )
230+ response = request_query (query , offset , batch_size , cache , session )
219231 if not response :
220232 break
221233
@@ -235,7 +247,7 @@ def search_semanticscholar(query: str, *, limit: int = None, batch_size: int = 1
235247 break
236248
237249 for paper_id in progress_bar (paper_ids ):
238- doc = request_paper (paper_id , cache )
250+ doc = request_paper (paper_id , cache , session )
239251
240252 if doc :
241253 docs .append (ScholarDocument (doc ))
0 commit comments