add seed and code to allow multi-sample averaging for variance reduction. also run linter.
thashim committed Nov 15, 2023
1 parent 0575996 commit 8c8051e
Showing 3 changed files with 39 additions and 17 deletions.
27 changes: 19 additions & 8 deletions arxiv_scraper.py
@@ -31,18 +31,24 @@ def __hash__(self):


def is_earlier(ts1, ts2):
-    return int(ts1.replace('.','')) < int(ts2.replace('.',''))
+    return int(ts1.replace(".", "")) < int(ts2.replace(".", ""))


def get_papers_from_arxiv_api(area: str, timestamp, last_id) -> List[Paper]:
    # grabs papers from the arxiv API endpoint.
    # we need this because the RSS feed is buggy, and drops some of the most recent papers silently.
    end_date = timestamp
-    start_date = timestamp-timedelta(days=4)
+    start_date = timestamp - timedelta(days=4)
    search = arxiv.Search(
-        query="("+area+") AND submittedDate:[" + start_date.strftime('%Y%m%d') + "* TO " + end_date.strftime('%Y%m%d') + "*]",
+        query="("
+        + area
+        + ") AND submittedDate:["
+        + start_date.strftime("%Y%m%d")
+        + "* TO "
+        + end_date.strftime("%Y%m%d")
+        + "*]",
        max_results=None,
-        sort_by=arxiv.SortCriterion.SubmittedDate
+        sort_by=arxiv.SortCriterion.SubmittedDate,
    )
    results = list(arxiv.Client().results(search))
    api_papers = []
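For reference, here is what the reformatted concatenation produces for a hypothetical input (the area and dates below are invented for illustration; the resulting string is identical to what the old one-liner built):

    from datetime import datetime, timedelta

    end_date = datetime(2023, 11, 15)
    start_date = end_date - timedelta(days=4)
    query = (
        "("
        + "cs.CL"
        + ") AND submittedDate:["
        + start_date.strftime("%Y%m%d")
        + "* TO "
        + end_date.strftime("%Y%m%d")
        + "*]"
    )
    # query == "(cs.CL) AND submittedDate:[20231111* TO 20231115*]"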
@@ -52,12 +58,16 @@ def get_papers_from_arxiv_api(area: str, timestamp, last_id) -> List[Paper]:
        authors = [author.name for author in result.authors]
        summary = result.summary
        summary = unescape(re.sub("\n", " ", summary))
-        paper = Paper(authors=authors, title=result.title, abstract=summary, arxiv_id=result.get_short_id()[:10])
+        paper = Paper(
+            authors=authors,
+            title=result.title,
+            abstract=summary,
+            arxiv_id=result.get_short_id()[:10],
+        )
        api_papers.append(paper)
    return api_papers


-
def get_papers_from_arxiv_rss(area: str, config: Optional[dict]) -> List[Paper]:
    # get the feed from http://export.arxiv.org/rss/ and use the updated timestamp to avoid duplicates
    updated = datetime.utcnow() - timedelta(days=1)
@@ -73,7 +83,9 @@ def get_papers_from_arxiv_rss(area: str, config: Optional[dict]) -> List[Paper]:
        return []
    # get the list of entries
    entries = feed.entries
-    timestamp = datetime.strptime(feed.headers["last-modified"], '%a, %d %b %Y %H:%M:%S GMT')
+    timestamp = datetime.strptime(
+        feed.headers["last-modified"], "%a, %d %b %Y %H:%M:%S GMT"
+    )
    # ugly hack: this should be the very oldest paper in the RSS feed that was not put on hold.
    # if ArXiv changes their RSS announcement format this line will break, but we have no other way of getting this info
    last_id = feed.entries[0].id.split("/")[-1]
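The strptime format string must match the RFC 1123-style Last-Modified header exactly; a quick standalone check with a made-up header value:

    from datetime import datetime

    # hypothetical header value for illustration
    ts = datetime.strptime("Wed, 15 Nov 2023 01:30:00 GMT", "%a, %d %b %Y %H:%M:%S GMT")
    # ts == datetime(2023, 11, 15, 1, 30)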
@@ -117,7 +129,6 @@ def get_papers_from_arxiv_rss_api(area: str, config: Optional[dict]) -> List[Paper]:
    return merged_paper_list


-
if __name__ == "__main__":
    paper_list = get_papers_from_arxiv_rss("math.AC", None)
    print(paper_list)
27 changes: 19 additions & 8 deletions filter_papers.py
@@ -69,13 +69,18 @@ def call_chatgpt(full_prompt, openai_client, model, num_samples):
        messages=[{"role": "user", "content": full_prompt}],
        temperature=0.0,
        n=int(num_samples),
-        seed=0
+        seed=0,
    )


def run_and_parse_chatgpt(full_prompt, openai_client, config):
    # just runs the chatgpt prompt, tries to parse the resulting JSON
-    completion = call_chatgpt(full_prompt, openai_client, config["SELECTION"]["model"], config["FILTERING"]["num_samples"])
+    completion = call_chatgpt(
+        full_prompt,
+        openai_client,
+        config["SELECTION"]["model"],
+        config["FILTERING"]["num_samples"],
+    )
    json_dicts = defaultdict(list)
    for choice in completion.choices:
        out_text = choice.message.content
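For context, a minimal sketch of the sampling pattern this change enables, using the v1 openai Python client (the model name and prompt are placeholders, not the repo's config values):

    from openai import OpenAI

    client = OpenAI()
    completion = client.chat.completions.create(
        model="gpt-4",  # placeholder; the real model comes from config["SELECTION"]["model"]
        messages=[{"role": "user", "content": "Rate this abstract ..."}],
        temperature=0.0,
        n=3,     # draw several samples in a single request
        seed=0,  # best-effort reproducibility across runs
    )
    samples = [choice.message.content for choice in completion.choices]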
@@ -99,12 +104,18 @@ def run_and_parse_chatgpt(full_prompt, openai_client, config):
            continue
    all_dict = []
    for id, json_list in json_dicts.items():
-        rel_score = sum([float(jdict["RELEVANCE"]) for jdict in json_list])/float(len(json_list))
-        nov_score = sum([float(jdict["NOVELTY"]) for jdict in json_list]) / float(len(json_list))
-        new_dict = {"ARXIVID":json_list[0]["ARXIVID"],
-                    "COMMENT":json_list[0]["COMMENT"],
-                    "RELEVANCE":rel_score,
-                    "NOVELTY":nov_score}
+        rel_score = sum([float(jdict["RELEVANCE"]) for jdict in json_list]) / float(
+            len(json_list)
+        )
+        nov_score = sum([float(jdict["NOVELTY"]) for jdict in json_list]) / float(
+            len(json_list)
+        )
+        new_dict = {
+            "ARXIVID": json_list[0]["ARXIVID"],
+            "COMMENT": json_list[0]["COMMENT"],
+            "RELEVANCE": rel_score,
+            "NOVELTY": nov_score,
+        }
        all_dict.append(new_dict)
    return all_dict, calc_price(config["SELECTION"]["model"], completion.usage)
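A worked example of the new averaging step with hypothetical scores (the arXiv ID and values are invented for illustration):

    json_list = [
        {"ARXIVID": "2311.00000", "COMMENT": "...", "RELEVANCE": "8", "NOVELTY": "6"},
        {"ARXIVID": "2311.00000", "COMMENT": "...", "RELEVANCE": "7", "NOVELTY": "7"},
        {"ARXIVID": "2311.00000", "COMMENT": "...", "RELEVANCE": "9", "NOVELTY": "8"},
    ]
    rel_score = sum(float(j["RELEVANCE"]) for j in json_list) / len(json_list)  # 8.0
    nov_score = sum(float(j["NOVELTY"]) for j in json_list) / len(json_list)    # 7.0
    # averaging over num_samples draws shrinks the variance of each score by
    # roughly a factor of num_samples when the draws are independent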

2 changes: 1 addition & 1 deletion parse_json_to_md.py
@@ -67,7 +67,7 @@ def render_md_string(papers_dict):
    # join all papers into one string
    output_string = output_string + "\n".join(paper_strings)
    output_string += "\n\n---\n\n"
-    output_string += f'## Paper selection prompt\n{criterion}'
+    output_string += f"## Paper selection prompt\n{criterion}"
    return output_string


