Skip to content

Commit b4db842

Browse files
committed
Pre-commit
1 parent 8884231 commit b4db842

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

maro/rl/model/fc_block.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,13 @@ def __init__(
7272

7373
# build the net
7474
dims = [self._input_dim] + self._hidden_dims
75-
layers = [self._build_layer(in_dim, out_dim, activation=self._activation) for in_dim, out_dim in zip(dims, dims[1:])]
75+
layers = [
76+
self._build_layer(in_dim, out_dim, activation=self._activation) for in_dim, out_dim in zip(dims, dims[1:])
77+
]
7678
# top layer
77-
layers.append(self._build_layer(dims[-1], self._output_dim, head=self._head, activation=self._output_activation))
79+
layers.append(
80+
self._build_layer(dims[-1], self._output_dim, head=self._head, activation=self._output_activation),
81+
)
7882

7983
self._net = nn.Sequential(*layers)
8084

@@ -103,7 +107,13 @@ def input_dim(self) -> int:
103107
def output_dim(self) -> int:
104108
return self._output_dim
105109

106-
def _build_layer(self, input_dim: int, output_dim: int, head: bool = False, activation: Type[torch.nn.Module] = None) -> nn.Module:
110+
def _build_layer(
111+
self,
112+
input_dim: int,
113+
output_dim: int,
114+
head: bool = False,
115+
activation: Type[torch.nn.Module] = None,
116+
) -> nn.Module:
107117
"""Build a basic layer.
108118
109119
BN -> Linear -> Activation -> Dropout

tests/rl/gym_wrapper/rl_component_bundle.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
77
from maro.simulator import Env
88

9-
from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine
10-
119
from .config import algorithm, env_conf
1210
from .env_sampler import GymEnvSampler
13-
11+
from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine
1412

1513
learn_env = Env(business_engine_cls=GymBusinessEngine, **env_conf)
1614
test_env = learn_env
@@ -44,7 +42,14 @@
4442
from tests.rl.algorithms.sac import get_sac_policy, get_sac_trainer
4543

4644
policies = [
47-
get_sac_policy(f"{algorithm}_{i}.policy", action_lower_bound, action_upper_bound, gym_state_dim, gym_action_dim, action_limit)
45+
get_sac_policy(
46+
f"{algorithm}_{i}.policy",
47+
action_lower_bound,
48+
action_upper_bound,
49+
gym_state_dim,
50+
gym_action_dim,
51+
action_limit,
52+
)
4853
for i in range(num_agents)
4954
]
5055
trainers = [get_sac_trainer(f"{algorithm}_{i}", gym_state_dim, gym_action_dim) for i in range(num_agents)]

0 commit comments

Comments
 (0)