agents package¶
- agents.register_agent(agent_name: str) Callable[[_T], _T] ¶
the decorator to register an agent
to register a new agent class, ensure it is a subclass of agents.BaseAgent and add the decorator above the class definition at the top of the file:
@register_agent("agent_name")
class YourAgentClass(BaseAgent): ...
- Parameters:
agent_name (str) – the name of the registered agent
- Returns:
the decorator to register the agent
- Return type:
Callable[[_T], _T]
Submodules¶
agents.BaseAgent module¶
- class agents.BaseAgent.BaseAgent(args: Namespace, env: BaseEnv, device: Optional[str] = None, test_mode: bool = False)¶
Bases:
Generic[BaseEnv]
the base class for all agents
- Parameters:
Generic (TypeVar) – the base type of the environment
- Raises:
NotImplementedError – train method not implemented
- static add_args(parser: ArgumentParser) None ¶
add arguments to the parser
to add arguments to the parser, override the method as follows:
@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument( ... )
- Parameters:
parser (argparse.ArgumentParser) – the parser to add arguments to
- __init__(args: Namespace, env: BaseEnv, device: Optional[str] = None, test_mode: bool = False) None ¶
the constructor for the BaseAgent
- Parameters:
args (argparse.Namespace) – arguments
env (BaseEnv) – the trading environment
device (Optional[str], optional) – torch device. Defaults to None, which means the device is automatically selected.
test_mode (bool, optional) – test or train mode. Defaults to False.
- train() None ¶
train the agent. Must be implemented by the subclass
- Raises:
NotImplementedError – train method not implemented
- test() None ¶
test the agent. Must be implemented by the subclass
- Raises:
NotImplementedError – test method not implemented
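The subclass contract above can be sketched end to end. The `BaseAgent` below is a minimal stand-in, just enough to show the hooks (the real base class also handles device selection); `NoopAgent` and its `--episodes` flag are illustrative, not part of the package:

```python
import argparse
from typing import Generic, Optional, TypeVar

EnvT = TypeVar("EnvT")

# Minimal stand-in for agents.BaseAgent showing the documented interface.
class BaseAgent(Generic[EnvT]):
    def __init__(self, args: argparse.Namespace, env: EnvT,
                 device: Optional[str] = None, test_mode: bool = False) -> None:
        self.args, self.env = args, env
        self.device, self.test_mode = device, test_mode

    @staticmethod
    def add_args(parser: argparse.ArgumentParser) -> None:
        pass  # subclasses register their own arguments here

    def train(self) -> None:
        raise NotImplementedError("train method not implemented")

    def test(self) -> None:
        raise NotImplementedError("test method not implemented")

# A do-nothing agent overriding the three hooks.
class NoopAgent(BaseAgent[object]):
    @staticmethod
    def add_args(parser: argparse.ArgumentParser) -> None:
        parser.add_argument("--episodes", type=int, default=10)

    def train(self) -> None:
        self.trained = True  # a real agent would run its training loop

    def test(self) -> None:
        self.tested = True

parser = argparse.ArgumentParser()
NoopAgent.add_args(parser)
args = parser.parse_args(["--episodes", "5"])
agent = NoopAgent(args, env=object())
agent.train()
print(args.episodes, agent.trained)
```

Calling `train` or `test` on a subclass that does not override them raises `NotImplementedError`, matching the Raises entries documented above.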
agents.DPG module¶
- class agents.DPG.DPG(args: Namespace, env: BaseEnv, device: Optional[str] = None, test_mode: bool = False)¶
Bases:
BaseAgent
The DPG class is a subclass of BaseAgent and implements the DPG algorithm.
- Raises:
ValueError – missing model_load_path for testing
- static add_args(parser: ArgumentParser) None ¶
add arguments to the parser
to add arguments to the parser, override the method as follows:
@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument( ... )
- Parameters:
parser (argparse.ArgumentParser) – the parser to add arguments to
- __init__(args: Namespace, env: BaseEnv, device: Optional[str] = None, test_mode: bool = False) None ¶
the constructor for the DPG agent
- Parameters:
args (argparse.Namespace) – arguments
env (BaseEnv) – the trading environment
device (Optional[str], optional) – torch device. Defaults to None, which means the device is automatically selected.
test_mode (bool, optional) – test or train mode. Defaults to False.
- train() None ¶
train the DPG agent
- _update_model() float ¶
update the model
- test() None ¶
test the DPG agent
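The update at the heart of a deterministic policy gradient step can be illustrated without any deep-learning dependencies. The toy below is a sketch under strong assumptions: a linear policy a = θ·s and a hand-coded critic Q(s, a) = −(a − s)², whose optimum is a = s (so the ideal θ is 1). The package's `_update_model` presumably performs the analogous chain-rule update with torch networks:

```python
import random

def dpg_step(theta: float, s: float, lr: float = 0.1) -> float:
    """One gradient-ascent step on Q(s, pi_theta(s)) via the chain rule."""
    a = theta * s
    dq_da = -2.0 * (a - s)   # ∂Q/∂a at the policy's action, for Q = -(a - s)^2
    da_dtheta = s            # ∂a/∂θ for the linear policy a = θ·s
    return theta + lr * dq_da * da_dtheta

random.seed(0)
theta = 0.0
for _ in range(200):
    theta = dpg_step(theta, random.uniform(0.5, 1.5))
print(round(theta, 3))
```

With this critic, each step pulls θ toward 1 regardless of the sampled state, so the iterate converges to the optimal policy parameter.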
agents.DQN module¶
- class agents.DQN.DQN(args: Namespace, env: BaseEnv, device: Optional[str] = None, test_mode: bool = False)¶
Bases:
BaseAgent[BaseEnv]
The DQN class is a subclass of BaseAgent and implements the DQN algorithm.
- Parameters:
BaseAgent (TypeVar) – the type of the environment
- Raises:
ValueError – missing model_load_path for testing
- static add_args(parser: ArgumentParser) None ¶
add arguments to the parser
to add arguments to the parser, override the method as follows:
@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument( ... )
- Parameters:
parser (argparse.ArgumentParser) – the parser to add arguments to
- __init__(args: Namespace, env: BaseEnv, device: Optional[str] = None, test_mode: bool = False) None ¶
the constructor for the DQN agent
- Parameters:
args (argparse.Namespace) – arguments
env (BaseEnv) – the trading environment
device (Optional[str], optional) – torch device. Defaults to None, which means the device is automatically selected.
test_mode (bool, optional) – test or train mode. Defaults to False.
- train() None ¶
train the DQN agent
- _update_epsilon() None ¶
update the epsilon value used for epsilon-greedy exploration
- _update_Q_network() float ¶
randomly sample a batch of experiences from the replay buffer and update the Q network
- Returns:
the training loss
- Return type:
float
- _update_target_network() None ¶
update the target network with the Q network weights
- test() None ¶
test the DQN agent
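Two of the housekeeping steps documented above, the epsilon update and the target-network update, can be sketched in plain Python. The decay rate, epsilon floor, and dict-based "weights" are assumptions for illustration; the package presumably uses its own hyperparameters and torch state dicts:

```python
def update_epsilon(epsilon: float, decay: float = 0.995,
                   epsilon_min: float = 0.05) -> float:
    """Multiplicative epsilon decay with a floor, a common exploration schedule."""
    return max(epsilon_min, epsilon * decay)

def update_target_network(target: dict, source: dict) -> None:
    """Hard update: copy the Q-network weights into the target network."""
    for name, value in source.items():
        target[name] = list(value)  # copy so later Q updates don't leak through

eps = 1.0
for _ in range(100):
    eps = update_epsilon(eps)

q_weights = {"layer1": [0.2, -0.1]}
target_weights = {"layer1": [0.0, 0.0]}
update_target_network(target_weights, q_weights)
print(round(eps, 4), target_weights)
```

Keeping the target network a delayed copy of the Q network stabilizes the bootstrapped targets used when computing the training loss.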
agents.MultiDQN module¶
- class agents.MultiDQN.MultiDQN(args: Namespace, env: DiscreteRealDataEnv1, device: Optional[str] = None, test_mode: bool = False)¶
Bases:
DQN[DiscreteRealDataEnv1]
The MultiDQN class is a subclass of DQN and implements the MultiDQN algorithm. It outputs the Q values of all possible actions simultaneously and takes DiscreteRealDataEnv1 as its environment.
References
original paper: https://arxiv.org/abs/1907.03665
reference implementation: https://github.com/Jogima-cyber/portfolio-manager
- static add_args(parser: ArgumentParser) None ¶
add arguments to the parser
to add arguments to the parser, override the method as follows:
@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument( ... )
- Parameters:
parser (argparse.ArgumentParser) – the parser to add arguments to
- __init__(args: Namespace, env: DiscreteRealDataEnv1, device: Optional[str] = None, test_mode: bool = False) None ¶
the constructor for the MultiDQN agent
- Parameters:
args (argparse.Namespace) – arguments
env (DiscreteRealDataEnv1) – the trading environment
device (Optional[str], optional) – torch device. Defaults to None, which means the device is automatically selected.
test_mode (bool, optional) – test or train mode. Defaults to False.
- train() None ¶
training for MultiDQN is composed of two steps: 1. pretrain the Q network; 2. train the Q network using the multiDQN procedure
- _pretrain() None ¶
the pretraining step of the multiDQN algorithm
- _multiDQN_train() None ¶
the training step of the multiDQN algorithm
- _update_Q_network() float ¶
randomly sample multiple experience lists from the replay buffer and update the Q network
- Returns:
the training loss
- Return type:
float
- test() None ¶
test the MultiDQN agent
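The two-phase flow of `train` described above can be sketched with stubbed phase bodies. `MultiDQNSketch` and its bookkeeping list are illustrative only; the real `_pretrain` and `_multiDQN_train` perform gradient updates:

```python
# Sketch of the documented two-phase MultiDQN training flow.
class MultiDQNSketch:
    def __init__(self) -> None:
        self.phases = []  # records which phase ran, and in what order

    def _pretrain(self) -> None:
        self.phases.append("pretrain")  # stub for Q-network pretraining

    def _multiDQN_train(self) -> None:
        self.phases.append("multiDQN_train")  # stub for the multiDQN step

    def train(self) -> None:
        self._pretrain()
        self._multiDQN_train()

agent = MultiDQNSketch()
agent.train()
print(agent.phases)
```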