Skip to content

Commit

Permalink
Fix Lora switching (#967)
Browse files Browse the repository at this point in the history
fix: #936
  • Loading branch information
ccssu authored Jun 19, 2024
1 parent ff16d33 commit 1b90bbb
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
2 changes: 1 addition & 1 deletion onediff_comfy_nodes/modules/oneflow/booster_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _(self, model: ModelPatcher, ckpt_name: Optional[str] = None, **kwargs):
)
set_compiled_options(compiled_model, graph_file)


model.weight_inplace_update = True
return model

@execute.register(ControlNet)
Expand Down
19 changes: 8 additions & 11 deletions src/onediff/infer_compiler/backends/oneflow/dual_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,18 @@ def __setattr__(self, name: str, value: Any) -> None:
if name in ["_torch_module", "_oneflow_module"]:
super().__setattr__(name, value)
else: # TODO: aviod memory up when set attr
_torch_module: torch.nn.Module = self._torch_module
module = self._torch_module
if (
hasattr(_torch_module, "_disable_param_update")
and _torch_module._disable_param_update
hasattr(module, "_disable_param_update")
and module._disable_param_update
):
return

if self._oneflow_module is not None:
v = torch2oflow(value)
if isinstance(v, flow.Tensor):
obj = getattr(self._oneflow_module, name)
obj.copy_(v)
else:
setattr(self._oneflow_module, name, v)
setattr(_torch_module, name, value)
torch_obj = getattr(module, name)
if hasattr(torch_obj, 'copy_'):
torch_obj.copy_(value)
else:
setattr(module, name, value)

def extra_repr(self) -> str:
return self._torch_module.extra_repr()
Expand Down

0 comments on commit 1b90bbb

Please sign in to comment.