diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py
index 2052a79987..d8edd44459 100644
--- a/flair/datasets/__init__.py
+++ b/flair/datasets/__init__.py
@@ -100,6 +100,7 @@
 # Expose all document classification datasets
 from .document_classification import (
+    AGNEWS,
     AMAZON_REVIEWS,
     COMMUNICATIVE_FUNCTIONS,
     GERMEVAL_2018_OFFENSIVE_LANGUAGE,
@@ -314,6 +315,7 @@
     "SentenceDataset",
     "MongoDataset",
     "StringDataset",
+    "AGNEWS",
     "ANAT_EM",
     "AZDZ",
     "BC2GM",
diff --git a/flair/datasets/document_classification.py b/flair/datasets/document_classification.py
index 2c0d6b3416..0bbc471818 100644
--- a/flair/datasets/document_classification.py
+++ b/flair/datasets/document_classification.py
@@ -907,6 +907,74 @@ def __init__(
         super().__init__(data_folder, tokenizer=tokenizer, memory_mode=memory_mode, **corpusargs)
 
 
+class AGNEWS(ClassificationCorpus):
+    """The AG's News Topic Classification Corpus, classifying news into 4 coarse-grained topics.
+
+    Labels: World, Sports, Business, Sci/Tech.
+    """
+
+    def __init__(
+        self,
+        base_path: Optional[Union[str, Path]] = None,
+        tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
+        memory_mode="partial",
+        **corpusargs,
+    ):
+        """Instantiates the AGNews Classification Corpus with 4 classes.
+
+        :param base_path: Provide this only if you store the AGNEWS corpus in a specific folder, otherwise use default.
+        :param tokenizer: Custom tokenizer to use (default is SpaceTokenizer).
+        :param memory_mode: Set to 'partial' by default. Can also be 'full' or 'none'.
+        :param corpusargs: Other args for ClassificationCorpus.
+        """
+        base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
+
+        dataset_name = self.__class__.__name__.lower()
+
+        data_folder = base_path / dataset_name
+
+        # download data from the same source as HuggingFace's implementation
+        agnews_path = "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/"
+
+        original_filenames = ["train.csv", "test.csv", "classes.txt"]
+        new_filenames = ["train.txt", "test.txt"]
+
+        for original_filename in original_filenames:
+            cached_path(f"{agnews_path}{original_filename}", Path("datasets") / dataset_name / "original")
+
+        data_file = data_folder / new_filenames[0]
+        label_names = []
+        label_path = original_filenames[-1]
+
+        # read the label names in their original order (line i holds the name of label id i+1)
+        with open(data_folder / "original" / label_path) as f:
+            for line in f:
+                line = line.rstrip()
+                label_names.append(line)
+
+        original_filenames = original_filenames[:-1]
+        if not data_file.is_file():
+            for original_filename, new_filename in zip(original_filenames, new_filenames):
+                with open(data_folder / "original" / original_filename, encoding="utf-8") as open_fp, open(
+                    data_folder / new_filename, "w", encoding="utf-8"
+                ) as write_fp:
+                    csv_reader = csv.reader(
+                        open_fp, quotechar='"', delimiter=",", quoting=csv.QUOTE_ALL, skipinitialspace=True
+                    )
+                    for row in csv_reader:
+                        label, title, description = row
+                        # original labels are the 1-based ids [1, 2, 3, 4]; map each id
+                        # to its name ['World', 'Sports', 'Business', 'Sci/Tech']
+                        text = " ".join((title, description))
+
+                        new_label = "__label__"
+                        new_label += label_names[int(label) - 1]
+
+                        write_fp.write(f"{new_label} {text}\n")
+
+        super().__init__(data_folder, label_type="topic", tokenizer=tokenizer, memory_mode=memory_mode, **corpusargs)
+
+
 class STACKOVERFLOW(ClassificationCorpus):
     """Stackoverflow corpus classifying questions into one of 20 labels.
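
For reviewers, a minimal usage sketch of the new loader (illustrative only, not part of the diff; it assumes network access on first instantiation and uses flair's standard Corpus API):

    from flair.datasets import AGNEWS

    # the first call downloads train.csv/test.csv/classes.txt and converts them
    # to FastText-formatted train.txt/test.txt, one document per line, e.g.:
    #   __label__Sci/Tech <title> <description>
    corpus = AGNEWS()
    print(corpus)  # train/dev/test sentence counts

    # labels are registered under the "topic" label type set in super().__init__
    label_dict = corpus.make_label_dictionary(label_type="topic")
    print(label_dict)  # World, Sports, Business, Sci/Tech

Since only train and test splits are downloaded, a dev split should be sampled from the training data by the base Corpus, per flair's usual behavior for missing splits.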