Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for https://github.com/bigscience-workshop/promptsource/issues/20 #21

Merged
9 changes: 8 additions & 1 deletion promptsource/promptsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import requests
import streamlit as st
from session import _get_state
from utils import _ADDITIONAL_ENGLISH_DATSETS
from utils import (_ADDITIONAL_ENGLISH_DATSETS, removeHyphen,
renameDatasetColumn)

from templates import Template, TemplateCollection

Expand Down Expand Up @@ -213,6 +214,10 @@ def list_datasets(template_collection, _priority_filter):
)
st.markdown(md)

st.sidebar.subheader("Dataset Schema")
st.sidebar.write(render_features(dataset.features))
dataset = renameDatasetColumn(dataset)

dataset_templates = template_collection.get_dataset(dataset_key, conf_option.name if conf_option else None)

template_list = dataset_templates.keys
Expand All @@ -230,6 +235,8 @@ def list_datasets(template_collection, _priority_filter):
example_index = st.sidebar.slider("Select the example index", 0, len(dataset) - 1)

example = dataset[example_index]
example = removeHyphen(example)

st.sidebar.write(example)

col1, _, col2 = st.beta_columns([18, 1, 6])
Expand Down
20 changes: 20 additions & 0 deletions promptsource/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,23 @@
"xtreme",
"yelp_polarity",
]


def removeHyphen(example):
    """Return a copy of ``example`` whose string keys use ``_`` instead of ``-``.

    Args:
        example: A mapping (anything exposing ``.items()``) from field
            names to values.

    Returns:
        A new ``dict`` with the same values in the same insertion order.
        Every ``-`` in a string key is replaced by ``_``; keys that are not
        strings are copied through unchanged. The input mapping is not
        modified.

    Note:
        If two keys become equal after the replacement (for example
        ``"a-b"`` and ``"a_b"``), the one that appears later in
        ``example`` wins.
    """
    return {
        # ``str.replace`` is a no-op when there is no hyphen, so no separate
        # membership test is needed; the ``isinstance`` guard keeps
        # non-string keys from raising.
        (key.replace("-", "_") if isinstance(key, str) else key): value
        for key, value in example.items()
    }


def renameDatasetColumn(dataset):
    """Return ``dataset`` with hyphens in its column names turned into underscores.

    Args:
        dataset: An object exposing ``column_names`` (an iterable of names)
            and ``rename_column(old, new)``. Presumably a Hugging Face
            ``datasets.Dataset`` -- confirm against the caller.

    Returns:
        The result of chaining ``rename_column`` once per hyphenated column
        name, or the original ``dataset`` object when no name contains
        a hyphen.
    """
    # ``dataset.column_names`` is evaluated once, before ``dataset`` is rebound
    # inside the loop, so iteration always runs over the original names.
    for original_name in dataset.column_names:
        if "-" not in original_name:
            continue
        sanitized_name = original_name.replace("-", "_")
        dataset = dataset.rename_column(original_name, sanitized_name)
    return dataset