Commit 64757cb
Parent(s): 036642a
fix(muon): free tensors that are no longer needed

Changed files:
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -7
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py +10 -7
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -7
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py +10 -7
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py +10 -7
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py +10 -7
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py +10 -7
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -7
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -7
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +10 -7
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +10 -7
- torch-ext/optimizer/muon.py +10 -7
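
All of the functional changes live in muon.py (the copies under build/ receive the identical change); _ops.py and the .abi3.so pointers only pick up the new build namespace. The optimizer now drops tensor references as soon as the next pipeline stage has taken over, and a new constructor flag none_grad (default True) additionally releases p.grad right after it has been gathered. A minimal usage sketch (the import path and the Linear model are assumptions for illustration; only the model-first constructor and the none_grad keyword come from the diff below):

import torch
from optimizer import Muon  # assumption: the built "optimizer" extension package is importable

model = torch.nn.Linear(1024, 1024)
# none_grad=True (the new default) frees each p.grad once its gather has been issued;
# pass none_grad=False to keep gradients around after step().
opt = Muon(model, none_grad=True)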
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_febdf5b_dirty
-ops = torch.ops._optimizer_febdf5b_dirty
+from . import _optimizer_036642a_dirty
+ops = torch.ops._optimizer_036642a_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_febdf5b_dirty::{op_name}"
+    return f"_optimizer_036642a_dirty::{op_name}"
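
The change to each _ops.py is mechanical: the shim re-binds to the namespace of the freshly built extension, _optimizer_036642a_dirty. A self-contained sketch of the pattern it implements (the op name "my_op" is hypothetical; loading of the actual shared library is handled elsewhere by the package):

NAMESPACE = "_optimizer_036642a_dirty"  # regenerated for every rebuild of the extension

def add_op_namespace_prefix(op_name: str) -> str:
    # Fully qualified name used when registering or looking up a custom op,
    # e.g. torch.ops._optimizer_036642a_dirty.my_op <-> "_optimizer_036642a_dirty::my_op".
    return f"{NAMESPACE}::{op_name}"

print(add_op_namespace_prefix("my_op"))  # -> _optimizer_036642a_dirty::my_op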
build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9c77e5647b6056bfaee25050cca7948c40859db0a88fa4fcf40b67a85c947d8c
 size 1787272
build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
@@ -53,7 +53,7 @@ class _muon_state:
 
 
 @torch.no_grad()
-def _gather(p, state, rank, comm_stream):
+def _gather(p, state, rank, comm_stream, none_grad):
     g = p.grad
     mesh = g.device_mesh
 
@@ -70,7 +70,6 @@ def _gather(p, state, rank, comm_stream):
         group=mesh.get_group(),
     )
     if rank == state.worker_rank:
-        # TODO: Consider ,,,
         if state.gathered_grad is not None:
             raise RuntimeError(
                 "Gather event already exists, which should not happen."
@@ -81,6 +80,8 @@ def _gather(p, state, rank, comm_stream):
     else:
         state.gathered_grad = None
         state.gather_event = None
+    if none_grad:
+        p.grad = None
 
 
 @torch.no_grad()
@@ -94,8 +95,8 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
-
-
+        # Clear the gathered gradient to free memory
+        state.gathered_grad = None
     else:
         state.computed_u = None
         state.compute_event = None
@@ -123,8 +124,8 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
         group=mesh.get_group(),
     )
     if rank == state.worker_rank:
-
-
+        # Clear u to free memory
+        state.computed_u = None
         u = DTensor.from_local(
             u,
             placements=p.placements,
@@ -172,6 +173,7 @@ class Muon(torch.optim.Optimizer):
         adamw_wd=0.1,
         adamw_betas=(0.9, 0.95),
         adamw_eps=1e-8,
+        none_grad=True,
         debug=False,
     ):
         defaults = dict(
@@ -182,6 +184,7 @@ class Muon(torch.optim.Optimizer):
             ns_steps=ns_steps,
             adamw_betas=adamw_betas,
             adamw_eps=adamw_eps,
+            none_grad=none_grad,
         )
 
         super().__init__(model.parameters(), defaults)
@@ -350,7 +353,7 @@ class Muon(torch.optim.Optimizer):
         def enqueue_gathers(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _gather(p, state, self.rank, self.comm_stream)
+                _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
 
         def enqueue_computes(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
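
The three new None assignments in muon.py follow one idea: once the next pipeline stage has recorded its CUDA event (or taken its own handle to the data), the Python reference held in _muon_state is the only thing keeping that intermediate buffer alive, so dropping it lets the caching allocator reuse the memory for the next parameter. A rough, repository-independent illustration of the effect (sizes are arbitrary):

import torch

# Hypothetical illustration: a live reference keeps the allocation reserved;
# rebinding it to None returns the block to the caching allocator.
if torch.cuda.is_available():
    gathered = torch.randn(4096, 4096, device="cuda")   # stand-in for state.gathered_grad
    update = gathered.sign()                             # stand-in for the computed update u
    print(torch.cuda.memory_allocated())                 # both buffers are live
    gathered = None                                      # mirrors state.gathered_grad = None
    print(torch.cuda.memory_allocated())                 # the gathered buffer can now be reused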
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py
CHANGED (identical to the _ops.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux: namespace updated to _optimizer_036642a_dirty, +3 -3)
build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:94ea66089cc8d9eda72b017733a9e05e4fee5a2f04c50658b690d2c19f0d3068
 size 1824224
build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py
CHANGED (identical to the muon.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux, +10 -7)
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED (identical to the _ops.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux: namespace updated to _optimizer_036642a_dirty, +3 -3)
build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:46e01e1d957ada2d485b30cd60bc3ef7230b8857dffc59f2e7924339761ec577
 size 1824224
build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED (identical to the muon.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux, +10 -7)
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py
CHANGED (identical to the _ops.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux: namespace updated to _optimizer_036642a_dirty, +3 -3)
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:a825a0cd31d8c1b91aa9db4b24248d7fc0a506615f625a385b40e6002025c7dd
 size 1749744
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py
CHANGED (identical to the muon.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux, +10 -7)
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py
CHANGED (identical to the _ops.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux: namespace updated to _optimizer_036642a_dirty, +3 -3)
build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:579e9ddf66a4f17ead9232c2f32e6327fe6a3f16dd235e2e73e6cb282de1797e
 size 1787192
build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py
CHANGED (identical to the muon.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux, +10 -7)
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py
CHANGED (identical to the _ops.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux: namespace updated to _optimizer_036642a_dirty, +3 -3)
build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:beacb4ba2d56463b6d444875728b3462cb3ff6c1449e3c9693cd665bfbbbbb73
 size 1824184
build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py
CHANGED (identical to the muon.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux, +10 -7)
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py
CHANGED (identical to the _ops.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux: namespace updated to _optimizer_036642a_dirty, +3 -3)
build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9b04b011803d328d8dcd2edcf4c3840ddbb1bb2f093464c208f0ba2faf4f16bc
 size 1824184
build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py
CHANGED (identical to the muon.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux, +10 -7)
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED (identical to the _ops.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux: namespace updated to _optimizer_036642a_dirty, +3 -3)
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:ad6c725009f2e776b99d3134c75f15e11dd7fe75fe4ba1fa94779018c7871f8c
 size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED (identical to the muon.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux, +10 -7)
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED (identical to the _ops.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux: namespace updated to _optimizer_036642a_dirty, +3 -3)
build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:50cb5819ff08a2179d78cd98164d07fd3cef1b66ee7703d599a310dfb140b9d1
 size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED (identical to the muon.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux, +10 -7)
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED (identical to the _ops.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux: namespace updated to _optimizer_036642a_dirty, +3 -3)
build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9c75e42265f382addc71327ad5628e8a2414da5872791c975e384708c4acd549
 size 1883352
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED (identical to the muon.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux, +10 -7)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED (identical to the _ops.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux: namespace updated to _optimizer_036642a_dirty, +3 -3)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9a2363d4311d6a75fbcc03e6d4a71c73dae4d54e00a30135d25198d4078c6b0f
 size 1749648
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED (identical to the muon.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux, +10 -7)
torch-ext/optimizer/muon.py
CHANGED (identical to the muon.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux, +10 -7)
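
A closing note on why the frees are ordered the way they are: in the muon.py hunks above, each buffer is cleared only after the event that guards it has been recorded (state.compute_event.record() comes before state.gathered_grad = None), which is the usual way to make a cross-stream hand-off safe before dropping the producer's reference. A generic event-then-release sketch, independent of this repository:

import torch

if torch.cuda.is_available():
    producer = torch.cuda.Stream()
    with torch.cuda.stream(producer):
        buf = torch.empty(1024, 1024, device="cuda").normal_()  # produced on a side stream
        done = torch.cuda.Event()
        done.record()                                            # mark the point where buf is ready
    torch.cuda.current_stream().wait_event(done)                 # consumer orders itself after the producer
    result = buf.sum()                                           # consume the buffer
    torch.cuda.synchronize()                                     # conservative: wait for all queued work
    buf = None                                                   # then drop the reference, like state.gathered_grad = None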