Add extraction of all patent counts to the pipeline; modify other code for extensibility to enable this #198

Merged
20 changes: 10 additions & 10 deletions .github/workflows/pythonapp.yml
@@ -60,18 +60,18 @@ jobs:
# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# - name: Test with pytest
# run: |
# pip install pytest
# cd web
# python3 -m pytest tests
- name: Test with pytest
run: |
pip install pytest
cd web
python3 -m pytest tests
# cd ..
# python3 -m pytest company_linkage/test_aggregate_organizations.py
# - name: Report python coverage
# uses: orgoro/coverage@v3
# with:
# coverageFile: coverage/python.xml
# token: ${{ secrets.GITHUB_TOKEN }}
- name: Report python coverage
uses: orgoro/coverage@v3
with:
coverageFile: coverage/python.xml
token: ${{ secrets.GITHUB_TOKEN }}
- name: Report javascript coverage
uses: MishaKav/[email protected]
with:
57 changes: 42 additions & 15 deletions company_linkage/parat_data_dag.py
@@ -212,7 +212,7 @@
)

run_papers = []
for paper_type in ["top", "all"]:
for paper_type in ["top_paper", "highly_cited_paper", "all_paper", "all_patent"]:

run_get_paper_counts = GKEStartPodOperator(
task_id=f"run_get_{paper_type}_counts",
@@ -221,9 +221,9 @@
cluster_name="cc2-task-pool",
name=f"run_get_{paper_type}_counts",
cmds=["/bin/bash"],
arguments=["-c", (f"echo 'getting {paper_type} paper counts!' ; rm -r {paper_type} || true ; "
arguments=["-c", (f"echo 'getting {paper_type} counts!' ; rm -r {paper_type} || true ; "
f"mkdir -p {paper_type} && "
f"python3 {paper_type}_papers.py {paper_type}/{paper_type}_paper_counts.jsonl && "
f"python3 {paper_type}s.py {paper_type}/{paper_type}_counts.jsonl && "
f"gsutil -m cp -r {paper_type} gs://{DATA_BUCKET}/{tmp_dir}/ ")],
namespace="default",
image=f"us.gcr.io/{PROJECT_ID}/parat",
@@ -254,25 +254,47 @@
load_top_papers = GCSToBigQueryOperator(
task_id=f"load_top_papers",
bucket=DATA_BUCKET,
source_objects=[f"{tmp_dir}/top/top_paper_counts.jsonl"],
source_objects=[f"{tmp_dir}/top_paper/top_paper_counts.jsonl"],
schema_object=f"{schema_dir}/top_papers_schema.json",
destination_project_dataset_table=f"{staging_dataset}.top_paper_counts",
source_format="NEWLINE_DELIMITED_JSON",
create_disposition="CREATE_IF_NEEDED",
write_disposition="WRITE_TRUNCATE"
)

load_highly_cited_papers = GCSToBigQueryOperator(
task_id=f"load_highly_cited_papers",
bucket=DATA_BUCKET,
source_objects=[f"{tmp_dir}/highly_cited_paper/highly_cited_paper_counts.jsonl"],
schema_object=f"{schema_dir}/highly_cited_papers_schema.json",
destination_project_dataset_table=f"{staging_dataset}.highly_cited_paper_counts",
source_format="NEWLINE_DELIMITED_JSON",
create_disposition="CREATE_IF_NEEDED",
write_disposition="WRITE_TRUNCATE"
)

load_all_papers = GCSToBigQueryOperator(
task_id=f"load_all_papers",
bucket=DATA_BUCKET,
source_objects=[f"{tmp_dir}/all/all_paper_counts.jsonl"],
source_objects=[f"{tmp_dir}/all_paper/all_paper_counts.jsonl"],
schema_object=f"{schema_dir}/all_papers_schema.json",
destination_project_dataset_table=f"{staging_dataset}.all_paper_counts",
source_format="NEWLINE_DELIMITED_JSON",
create_disposition="CREATE_IF_NEEDED",
write_disposition="WRITE_TRUNCATE"
)

load_all_patents = GCSToBigQueryOperator(
task_id=f"load_all_patents",
bucket=DATA_BUCKET,
source_objects=[f"{tmp_dir}/all_patent/all_patent_counts.jsonl"],
schema_object=f"{schema_dir}/all_patents_schema.json",
destination_project_dataset_table=f"{staging_dataset}.all_patent_counts",
source_format="NEWLINE_DELIMITED_JSON",
create_disposition="CREATE_IF_NEEDED",
write_disposition="WRITE_TRUNCATE"
)

start_visualization_tables = DummyOperator(task_id="start_visualization_tables")
wait_for_visualization_tables = DummyOperator(task_id="wait_for_visualization_tables")

@@ -319,7 +341,7 @@

curr_date = datetime.now().strftime('%Y%m%d')
prod_tables = ["visualization_data", "paper_visualization_data",
"patent_visualization_data", "workforce_visualization_data"]
"patent_visualization_data", "workforce_visualization_data", "all_visualization_data"]
for table in prod_tables:
prod_table_name = f"{production_dataset}.{table}"
copy_to_production = BigQueryToBigQueryOperator(
@@ -329,22 +351,25 @@
create_disposition="CREATE_IF_NEEDED",
write_disposition="WRITE_TRUNCATE"
)
pop_descriptions = PythonOperator(
task_id="populate_column_documentation_for_" + table,
op_kwargs={
"input_schema": f"{os.environ.get('DAGS_FOLDER')}/schemas/parat/{table}.json",
"table_name": prod_table_name
},
python_callable=update_table_descriptions
)
table_backup = BigQueryToBigQueryOperator(
task_id=f"back_up_{table}",
source_project_dataset_tables=[f"{staging_dataset}.{table}"],
destination_project_dataset_table=f"{backups_dataset}.{table}_{curr_date}",
create_disposition="CREATE_IF_NEEDED",
write_disposition="WRITE_TRUNCATE"
)
wait_for_checks >> copy_to_production >> pop_descriptions >> table_backup >> wait_for_copy
if table != "all_visualization_data":
pop_descriptions = PythonOperator(
task_id="populate_column_documentation_for_" + table,
op_kwargs={
"input_schema": f"{os.environ.get('DAGS_FOLDER')}/schemas/parat/{table}.json",
"table_name": prod_table_name
},
python_callable=update_table_descriptions
)
wait_for_checks >> copy_to_production >> pop_descriptions >> table_backup >> wait_for_copy
else:
wait_for_checks >> copy_to_production >> table_backup >> wait_for_copy

# post success to slack
msg_success = get_post_success("PARAT tables updated!", dag)
@@ -365,7 +390,9 @@
>> load_ai_patent_grants
>> run_papers
>> load_top_papers
>> load_highly_cited_papers
>> load_all_papers
>> load_all_patents
>> start_visualization_tables
)
(
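For the new all_patent entry, the loop above derives the script name, working directory, and output path from the paper_type string itself. A minimal sketch of how those f-strings expand; DATA_BUCKET and tmp_dir below are hypothetical stand-ins for the DAG's real values:

    # Sketch: expansion of the pod command for paper_type = "all_patent"
    DATA_BUCKET = "example-data-bucket"  # hypothetical; the DAG supplies the real bucket
    tmp_dir = "tmp"                      # hypothetical; the DAG supplies the real prefix
    paper_type = "all_patent"
    task_id = f"run_get_{paper_type}_counts"  # -> "run_get_all_patent_counts"
    command = (f"echo 'getting {paper_type} counts!' ; rm -r {paper_type} || true ; "
               f"mkdir -p {paper_type} && "
               f"python3 {paper_type}s.py {paper_type}/{paper_type}_counts.jsonl && "
               f"gsutil -m cp -r {paper_type} gs://{DATA_BUCKET}/{tmp_dir}/ ")
    # -> "... python3 all_patents.py all_patent/all_patent_counts.jsonl && gsutil -m cp -r all_patent ..."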
28 changes: 28 additions & 0 deletions company_linkage/parat_scripts/all_patents.py
@@ -0,0 +1,28 @@
import argparse

from get_ai_counts import CountGetter


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("output_file", type=str,
help="A jsonl file for writing output data to create new tables")
args = parser.parse_args()
if not args.output_file:
parser.print_help()
return
if "jsonl" not in args.output_file:
parser.print_help()
return
patent_finder = CountGetter()
patent_finder.get_identifiers()
# These are the only two lines that make this different from running AI patents
# We select from a different table and AI is false
table_name = "linked_all_patents"
# And we write out our data to a different variable
companies = patent_finder.run_query_id_patents(table_name, ai=False, test=True)
patent_finder.write_output(companies, args.output_file)


if __name__ == "__main__":
main()
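Because the query runs with ai=False, each row carries only the identifier columns rather than the full AI field breakdown. A sketch of a single record as run_query_id_patents builds it before write_output serializes it to the jsonl file; the values are hypothetical:

    # Hypothetical example of one row on the all-patents path (no AI field columns)
    row = {"CSET_id": 101, "family_id": "12345678", "priority_year": 2015}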
51 changes: 31 additions & 20 deletions company_linkage/parat_scripts/get_ai_counts.py
@@ -117,6 +117,8 @@ def run_query_papers(self, table_name: str, field_name: str, test: bool = False,
if by_year and not field_name_by_year in row_dict:
row_dict[field_name_by_year] = []
companies.append(row_dict)
if test and i == 25:
break
return companies

def run_query_papers_by_year(self, table_name: str, field_name: str, regexes: list, rors: list) -> list:
@@ -160,50 +162,57 @@ def run_query_id_papers(self, table_name: str, test: bool = False) -> list:
:param test: False if not running as a unit test
:return:
"""
companies_query = f"""SELECT CSET_id, ror_id FROM
`gcp-cset-projects.high_resolution_entities.aggregated_organizations`"""
if test:
companies_query += """ LIMIT 25"""
companies_query = f"""-- SELECT CSET_id, ror_id FROM
# `gcp-cset-projects.high_resolution_entities.aggregated_organizations`"""
# if test:
# companies_query += """ LIMIT 25"""
client = bigquery.Client()
query_job = client.query(companies_query)
# query_job = client.query(companies_query)
company_rows = []
for i, row in enumerate(query_job):
self.company_ids.append(row["CSET_id"])
if row["CSET_id"] not in self.regex_dict:
for i, cset_id in enumerate(self.cset_ids):
if test and i == 25:
break
# self.company_ids.append(cset_id)
if cset_id not in self.regex_dict:
# if it's not in the regex_dict that's bad
print(row["CSET_id"])
print(cset_id)
else:
regexes = self.regex_dict[row['CSET_id']]
regexes = self.regex_dict[cset_id]
query = f"""SELECT DISTINCT merged_id, year, cv, nlp, robotics FROM `{table_name}`
WHERE regexp_contains(org_name, r'(?i){regexes[0]}') """
# if we have more than one regex for an org, include all of them
if len(regexes) > 1:
for regex in regexes[1:]:
query += f"""OR regexp_contains(org_name, r'(?i){regex}') """
if row["ror_id"]:
self.ror_dict[row["CSET_id"]] = row["ror_id"]
query += f"""OR ror_id IN ({str(row["ror_id"])[1:-1]})"""
if cset_id in self.ror_dict:
# self.ror_dict[row["CSET_id"]] = row["ror_id"]
query += f"""OR ror_id IN ({str(self.ror_dict[cset_id])[1:-1]})"""
query_job = client.query(query)
# get all the merged ids
for element in query_job:
company_rows.append({"CSET_id": row["CSET_id"], "merged_id": element["merged_id"],
company_rows.append({"CSET_id": cset_id, "merged_id": element["merged_id"],
"year": element["year"], "cv": element["cv"],
"nlp": element["nlp"], "robotics": element["robotics"]})
return company_rows

def run_query_id_patents(self, table_name: str):
def run_query_id_patents(self, table_name: str, ai: bool = True, test: bool = False):
"""
Get patent counts one by one using CSET_ids.
:return:
"""
patent_companies = []
for cset_id in self.company_ids:
for i, cset_id in enumerate(self.cset_ids):
if test and i == 25:
break
if cset_id in self.regex_dict:
regexes = self.regex_dict[cset_id]
rors = self.ror_dict[cset_id]
query = f"""SELECT DISTINCT
family_id,
family_id,
priority_year,
"""
if ai:
query += f"""
Physical_Sciences_and_Engineering,
Life_Sciences,
Security__eg_cybersecurity,
@@ -238,7 +247,8 @@ def run_query_id_patents(self, table_name: str):
Probabilistic_Reasoning,
Ontology_Engineering,
Machine_Learning,
Search_Methods
Search_Methods """
query += f"""
FROM
staging_ai_companies_visualization.{table_name}
WHERE regexp_contains(assignee, r'(?i){regexes[0]}') """
@@ -252,8 +262,9 @@
query_job = client.query(query)
for row in query_job:
new_patent_row = {"CSET_id": cset_id, "family_id": row["family_id"], "priority_year": row["priority_year"]}
patent_field_data = {i : row[i] for i in self.patent_fields}
new_patent_row.update(patent_field_data)
if ai:
patent_field_data = {i : row[i] for i in self.patent_fields}
new_patent_row.update(patent_field_data)
patent_companies.append(new_patent_row)
# company["ai_patents_by_year"] = self.run_query_patents_by_year(company["CSET_id"])
else:
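With the new ai and test keyword arguments, run_query_id_patents serves both the AI-patent and all-patent paths from the same CountGetter. A minimal sketch of the two call patterns, mirroring all_patents.py and the unit test; the table names are the ones used elsewhere in this PR:

    from get_ai_counts import CountGetter

    finder = CountGetter()
    finder.get_identifiers()
    # Default behavior: AI patents with the per-field breakdown, as before
    ai_rows = finder.run_query_id_patents("linked_ai_patents")
    # New path: all patents, identifier fields only; test=True stops after 25 companies
    all_rows = finder.run_query_id_patents("linked_all_patents", ai=False, test=True)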
28 changes: 28 additions & 0 deletions company_linkage/parat_scripts/highly_cited_papers.py
@@ -0,0 +1,28 @@
import argparse

from get_ai_counts import CountGetter


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("output_file", type=str,
help="A jsonl file for writing output data to create new tables")
args = parser.parse_args()
if not args.output_file:
parser.print_help()
return
if "jsonl" not in args.output_file:
parser.print_help()
return
paper_finder = CountGetter()
paper_finder.get_identifiers()
# These are the only two lines that make this different from running AI pubs
# We select from a different table
table_name = "staging_ai_companies_visualization.highly_cited_ai_publications"
# And we write out our data to a different variable
companies = paper_finder.run_query_papers(table_name, "highly_cited_ai_pubs", by_year=True)
paper_finder.write_output(companies, args.output_file)


if __name__ == "__main__":
main()
8 changes: 4 additions & 4 deletions company_linkage/parat_scripts/test_ai_counts.py
@@ -2,7 +2,6 @@
from get_ai_counts import CountGetter
import warnings


def ignore_warnings(test_func):
def do_test(self, *args, **kwargs):
with warnings.catch_warnings():
@@ -47,7 +46,7 @@ def test_run_query_papers(self):
self.assertIsNotNone(company["ai_pubs"])

"""
This is deprecated and can no longer be tested this because
This is deprecated and can no longer be tested this because
the by-year data isn't necessarily in the visualization table.
TODO: Find a new way to test
"""
@@ -86,8 +85,9 @@ def test_run_query_id_patents(self):
count_getter.get_identifiers()
table_name = "gcp-cset-projects.staging_ai_companies_visualization.ai_publications"
test = True
count_getter.run_query_id_papers(table_name, test)
patent_companies = count_getter.run_query_id_patents()
ai = True
# count_getter.run_query_id_papers(table_name, test)
patent_companies = count_getter.run_query_id_patents("linked_ai_patents", ai, test)
for company_row in patent_companies:
self.assertIsNotNone(company_row["CSET_id"])
self.assertEqual(type(company_row["CSET_id"]), int)
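A corresponding check for the new all-patents path could follow the same pattern inside this test; a sketch, assuming a linked_all_patents staging table is available to the test project:

    # Hypothetical extension of the test for the ai=False path
    patent_companies = count_getter.run_query_id_patents("linked_all_patents", ai=False, test=True)
    for company_row in patent_companies:
        self.assertIsNotNone(company_row["CSET_id"])
        self.assertIsNotNone(company_row["family_id"])
        self.assertIsNotNone(company_row["priority_year"])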