iamwyldecat committed
Commit 64757cb · 1 Parent(s): 036642a

fix(muon): free tensors that are no longer needed
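The main change is in muon.py: the Muon optimizer gains a none_grad option (default True). When enabled, _gather sets p.grad = None as soon as the gradient has been gathered to its worker rank, and the gathered gradient and the computed update are now released by dropping the Python references (state.gathered_grad = None, state.computed_u = None) instead of record_stream() followed by del. The _ops.py and .so changes below are the rebuilt extension picking up the new build hash. A minimal usage sketch follows; the import path and the defaults of the remaining constructor arguments are assumptions, not taken from this diff, and the sketch only constructs the optimizer (actually stepping Muon requires the DTensor/distributed setup it is written for).

    # Sketch only; import path and remaining constructor defaults are assumed.
    import torch
    from optimizer import Muon

    model = torch.nn.Linear(1024, 1024)

    # none_grad=True (the new default) frees each p.grad right after its gradient
    # has been gathered to the worker rank, before the compute and scatter phases.
    opt = Muon(model, none_grad=True)

    # none_grad=False keeps p.grad alive for the whole step, matching the behaviour
    # before this commit.
    opt_legacy = Muon(model, none_grad=False)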
Files changed (34)
  1. build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  2. build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
  3. build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -7
  4. build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py +3 -3
  5. build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
  6. build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py +10 -7
  7. build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  8. build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
  9. build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -7
  10. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py +3 -3
  11. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
  12. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py +10 -7
  13. build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  14. build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
  15. build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py +10 -7
  16. build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py +3 -3
  17. build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
  18. build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py +10 -7
  19. build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  20. build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
  21. build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py +10 -7
  22. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  23. build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
  24. build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -7
  25. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  26. build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
  27. build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -7
  28. build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  29. build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
  30. build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +10 -7
  31. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  32. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
  33. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +10 -7
  34. torch-ext/optimizer/muon.py +10 -7
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_febdf5b_dirty
-ops = torch.ops._optimizer_febdf5b_dirty
+from . import _optimizer_036642a_dirty
+ops = torch.ops._optimizer_036642a_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_febdf5b_dirty::{op_name}"
+    return f"_optimizer_036642a_dirty::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:98bd4b647ad0ecbae82a5e78f618475b47595c5bb68b3356c09ee8b1f1a57060
+oid sha256:9c77e5647b6056bfaee25050cca7948c40859db0a88fa4fcf40b67a85c947d8c
 size 1787272
build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -53,7 +53,7 @@ class _muon_state:
 
 
 @torch.no_grad()
-def _gather(p, state, rank, comm_stream):
+def _gather(p, state, rank, comm_stream, none_grad):
     g = p.grad
     mesh = g.device_mesh
 
@@ -70,7 +70,6 @@ def _gather(p, state, rank, comm_stream):
         group=mesh.get_group(),
     )
     if rank == state.worker_rank:
-        # TODO: Consider ,,,
         if state.gathered_grad is not None:
             raise RuntimeError(
                 "Gather event already exists, which should not happen."
@@ -81,6 +80,8 @@ def _gather(p, state, rank, comm_stream):
     else:
         state.gathered_grad = None
         state.gather_event = None
+    if none_grad:
+        p.grad = None
 
 
 @torch.no_grad()
@@ -94,8 +95,8 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
-        state.gathered_grad.record_stream(compute_stream)
-        del state.gathered_grad
+        # Clear the gathered gradient to free memory
+        state.gathered_grad = None
     else:
         state.computed_u = None
         state.compute_event = None
@@ -123,8 +124,8 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
         group=mesh.get_group(),
     )
     if rank == state.worker_rank:
-        state.computed_u.record_stream(comm_stream)
-        del state.computed_u
+        # Clear u to free memory
+        state.computed_u = None
     u = DTensor.from_local(
         u,
         placements=p.placements,
@@ -172,6 +173,7 @@ class Muon(torch.optim.Optimizer):
         adamw_wd=0.1,
         adamw_betas=(0.9, 0.95),
         adamw_eps=1e-8,
+        none_grad=True,
         debug=False,
     ):
         defaults = dict(
@@ -182,6 +184,7 @@ class Muon(torch.optim.Optimizer):
             ns_steps=ns_steps,
             adamw_betas=adamw_betas,
             adamw_eps=adamw_eps,
+            none_grad=none_grad,
         )
 
         super().__init__(model.parameters(), defaults)
@@ -350,7 +353,7 @@ class Muon(torch.optim.Optimizer):
         def enqueue_gathers(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _gather(p, state, self.rank, self.comm_stream)
+                _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
 
         def enqueue_computes(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
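Mechanically, the diff above replaces the record_stream() + del pattern with plain reference drops, presumably relying on the gather/compute events the surrounding code already records to order reuse of the memory. A standalone sketch of the two release patterns (not from this repo; needs a CUDA device):

    import torch

    comm_stream = torch.cuda.Stream()

    # Old pattern: tell the caching allocator the tensor is still in use on the
    # side stream, then drop the reference.
    old_buf = torch.empty(1 << 20, device="cuda")
    with torch.cuda.stream(comm_stream):
        old_buf.add_(1.0)              # stand-in for the gather/compute work
    old_buf.record_stream(comm_stream)
    del old_buf

    # New pattern in this commit: record an event for later synchronization and
    # simply drop the reference (state.gathered_grad = None / state.computed_u = None).
    new_buf = torch.empty(1 << 20, device="cuda")
    with torch.cuda.stream(comm_stream):
        new_buf.add_(1.0)
    done_event = torch.cuda.Event()
    done_event.record(comm_stream)
    new_buf = None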
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the namespace is renamed from _optimizer_febdf5b_dirty to _optimizer_036642a_dirty.
build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:796ac374cd2eec4260591c5a771c6b324f7dc6c8f34fc5dc211ab8afca546ffe
+oid sha256:94ea66089cc8d9eda72b017733a9e05e4fee5a2f04c50658b690d2c19f0d3068
 size 1824224
build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the namespace is renamed from _optimizer_febdf5b_dirty to _optimizer_036642a_dirty.
build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:254706f111eb794b1409ba48d25649ace5438e2c66027727e84490011ee4c5e6
+oid sha256:46e01e1d957ada2d485b30cd60bc3ef7230b8857dffc59f2e7924339761ec577
 size 1824224
build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the namespace is renamed from _optimizer_febdf5b_dirty to _optimizer_036642a_dirty.
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:027a26212a3dd705876ca83015a53b69d17d80fe7c1559fb01d7aacf614edb57
+oid sha256:a825a0cd31d8c1b91aa9db4b24248d7fc0a506615f625a385b40e6002025c7dd
 size 1749744
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the namespace is renamed from _optimizer_febdf5b_dirty to _optimizer_036642a_dirty.
build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:62c4408eaf54197941241ae6150afe1401a8bcf5854488a8b957d1f1546b388a
+oid sha256:579e9ddf66a4f17ead9232c2f32e6327fe6a3f16dd235e2e73e6cb282de1797e
 size 1787192
build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the namespace is renamed from _optimizer_febdf5b_dirty to _optimizer_036642a_dirty.
build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:166d253c91459e1aa1328a1550b0e3ec4bb7c6057870b1d7472a93cc987cf85a
+oid sha256:beacb4ba2d56463b6d444875728b3462cb3ff6c1449e3c9693cd665bfbbbbb73
 size 1824184
build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the namespace is renamed from _optimizer_febdf5b_dirty to _optimizer_036642a_dirty.
build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8bb7315b326f9af7a77e023c2b78511190235a8dcc9682abd5b49db1dc2b90f2
+oid sha256:9b04b011803d328d8dcd2edcf4c3840ddbb1bb2f093464c208f0ba2faf4f16bc
 size 1824184
build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the namespace is renamed from _optimizer_febdf5b_dirty to _optimizer_036642a_dirty.
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0a74351ee471271eaf1c8292ed01b7e71e6b1b683704144d68d90b67032ba386
+oid sha256:ad6c725009f2e776b99d3134c75f15e11dd7fe75fe4ba1fa94779018c7871f8c
 size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the namespace is renamed from _optimizer_febdf5b_dirty to _optimizer_036642a_dirty.
build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ffb7e3a786405106908da16e74506fe381b09e5e04a27b1062396e378f63f7f7
+oid sha256:50cb5819ff08a2179d78cd98164d07fd3cef1b66ee7703d599a310dfb140b9d1
 size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the namespace is renamed from _optimizer_febdf5b_dirty to _optimizer_036642a_dirty.
build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:45ee6c653f216af96705a25993d85751648ccd4714a8d6c8c36bdbc8dc19edc5
+oid sha256:9c75e42265f382addc71327ad5628e8a2414da5872791c975e384708c4acd549
 size 1883352
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the namespace is renamed from _optimizer_febdf5b_dirty to _optimizer_036642a_dirty.
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8427dae3274100063f3b003a7cebf9565318fcaa2fa340482b2ec9408e9dcea0
+oid sha256:9a2363d4311d6a75fbcc03e6d4a71c73dae4d54e00a30135d25198d4078c6b0f
 size 1749648
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
torch-ext/optimizer/muon.py CHANGED
Same change as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above.