Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed The issue with sumatra repeat and diff (only git) #356

Merged
merged 4 commits into from
Sep 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions sumatra/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def info(self):
def new_record(self, parameters={}, input_data=[], script_args="",
executable='default', repository='default',
main_file='default', version='current', launch_mode='default',
label=None, reason=None, timestamp_format='default'):
diff='', label=None, reason=None, timestamp_format='default'):
logger.debug("Creating new record")
if executable == 'default':
executable = deepcopy(self.default_executable)
Expand All @@ -193,7 +193,7 @@ def new_record(self, parameters={}, input_data=[], script_args="",
if timestamp_format == 'default':
timestamp_format = self.timestamp_format
working_copy = repository.get_working_copy()
version, diff = self.update_code(working_copy, version)
version, diff = self.update_code(working_copy, version, diff)
if label is None:
label = LABEL_GENERATORS[self.label_generator]()
record = Record(executable, repository, main_file, version, launch_mode,
Expand All @@ -212,15 +212,13 @@ def new_record(self, parameters={}, input_data=[], script_args="",

def launch(self, parameters={}, input_data=[], script_args="",
executable='default', repository='default', main_file='default',
version='current', launch_mode='default', label=None, reason=None,
version='current', launch_mode='default', diff='', label=None, reason=None,
timestamp_format='default', repeats=None):
"""Launch a new simulation or analysis."""
record = self.new_record(parameters, input_data, script_args,
executable, repository, main_file, version,
launch_mode, label, reason, timestamp_format)

launch_mode, diff, label, reason, timestamp_format)
record.run(with_label=self.data_label, project=self)

if 'matlab' in record.executable.name.lower():
record.register(record.repository.get_working_copy())
if repeats:
Expand All @@ -230,26 +228,29 @@ def launch(self, parameters={}, input_data=[], script_args="",
self.save()
return record.label

def update_code(self, working_copy, version='current'):
def update_code(self, working_copy, version='current', diff=''):
"""Check if the working copy has modifications and prompt to commit or revert them."""
# we really need to extend this to the dependencies, but we need to take extra special care that the
# code ends up in the same condition as before the run
logger.debug("Updating working copy to use version: %s" % version)
diff = ''
changed = working_copy.has_changed()
if version == 'current' or version == working_copy.current_version:
if (version == 'current' or version == working_copy.current_version) and not diff:
if changed:
if self.on_changed == "error":
raise UncommittedModificationsError("Code has changed, please commit your changes")
elif self.on_changed == "store-diff":
diff = working_copy.diff()
else:
raise ValueError("store-diff must be either 'error' or 'store-diff'")
elif changed:
raise UncommittedModificationsError(
"Code has changed. These changes will be lost when switching "
"to a different version, so please commit or stash your "
"changes and then retry.")
elif diff:
if changed:
raise UncommittedModificationsError(
"Code has changed. These changes will be lost when switching "
"to a different version, so please commit or stash your "
"changes and then retry.")
else:
working_copy.use_version(version)
working_copy.patch(diff)
elif version == 'latest':
working_copy.use_latest_version()
else:
Expand Down Expand Up @@ -395,9 +396,11 @@ def repeat(self, original_label, new_label=None):
repository=original.repository,
version=original.version,
launch_mode=original.launch_mode,
diff=original.diff,
label=new_label,
reason="Repeat experiment %s" % original.label,
repeats=original.label)
working_copy.reset()
working_copy.use_version(current_version) # ensure we switch back to the original working copy state
return new_label, original.label

Expand Down
23 changes: 23 additions & 0 deletions sumatra/versioncontrol/_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import git
import os
import shutil
import tempfile
from distutils.version import LooseVersion
from configparser import NoSectionError, NoOptionError
try:
Expand Down Expand Up @@ -102,6 +103,28 @@ def diff(self):
g = git.Git(self.path)
return g.diff('HEAD', color='never')

def reset(self):
"""Resets all uncommitted changes since the commit. Destructive, be
careful with use"""
g = git.Git(self.path)
g.reset('HEAD', '--hard')

def patch(self, diff):
"""Resets all uncommitted changes since the commit. Destructive, be
careful with use"""
assert not self.has_changed(), "Cannot patch dirty working copy"
# Create temp patch file
if diff[-1] != '\n':
diff = diff + '\n'
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file:
temp_file.write(diff)
temp_file_name = temp_file.name
try:
g = git.Git(self.path)
g.apply(temp_file_name)
finally:
os.remove(temp_file_name)

def content(self, digest, filename):
"""Get the file content from repository."""
repo = git.Repo(self.path)
Expand Down
10 changes: 10 additions & 0 deletions sumatra/versioncontrol/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,16 @@ def diff(self):
"""Return the difference between working copy and repository."""
raise NotImplementedError

def reset(self):
"""Resets all uncommitted changes since the commit. Destructive, be
careful with use"""
raise NotImplementedError

def patch(self, diff):
"""Applies the diff patch onto the repository files. Only works on a
clean working copy"""
raise NotImplementedError

def get_username(self):
"""
Return the username and e-mail of the current user, as understood by the
Expand Down