agents package

agents.register_agent(agent_name: str) → Callable[[_T], _T]

the decorator to register an agent

to register a new class as an agent, ensure the class is a subclass of agents.BaseAgent and add the decorator above the class definition:

@register_agent("agent_name")
class YourAgentClass(BaseAgent):
    ...
Parameters:

agent_name (str) – the name of the registered agent

Returns:

the decorator to register the agent

Return type:

Callable[[_T], _T]
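
For example, a complete registration might look like the following sketch, where the name "my_agent" and the class MyAgent are illustrative:

from agents import register_agent
from agents.BaseAgent import BaseAgent

@register_agent("my_agent")
class MyAgent(BaseAgent):
    ...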

Submodules

agents.BaseAgent module

class agents.BaseAgent.BaseAgent(args: Namespace, env: BaseEnv, device: Optional[str] = None, test_mode: bool = False)

Bases: Generic[BaseEnv]

the base class for all agents

Parameters:

Generic (TypeVar) – the base type of the environment

Raises:

NotImplementedError – train method not implemented

static add_args(parser: ArgumentParser) → None

add arguments to the parser

to add arguments to the parser, modify the method as follows:

@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        ...
    )

replacing the ellipsis with the desired argument definitions

Parameters:

parser (argparse.ArgumentParser) – the parser to add arguments to
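
For instance, a subclass might expose a learning-rate option. The flag name and default below are illustrative, not part of the actual API:

import argparse

from agents.BaseAgent import BaseAgent

class MyAgent(BaseAgent):
    @staticmethod
    def add_args(parser: argparse.ArgumentParser) -> None:
        # hypothetical option; flag name and default are illustrative
        parser.add_argument(
            "--learning_rate",
            type=float,
            default=1e-3,
            help="learning rate for the optimizer",
        )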

__init__(args: Namespace, env: BaseEnv, device: Optional[str] = None, test_mode: bool = False) → None

the constructor for the BaseAgent

Parameters:
  • args (argparse.Namespace) – arguments

  • env (BaseEnv) – the trading environment

  • device (Optional[str], optional) – torch device. Defaults to None, which means the device is automatically selected (see the sketch after this list).

  • test_mode (bool, optional) – test or train mode. Defaults to False.
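
When device is None, automatic selection typically means preferring a GPU when one is available. A common pattern (a sketch, not necessarily the exact logic used here):

import torch

# prefer a CUDA GPU when available, otherwise fall back to the CPU
if device is None:
    device = "cuda" if torch.cuda.is_available() else "cpu"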

train() → None

train the agent. Must be implemented by the subclass

Raises:

NotImplementedError – train method not implemented

test() → None

test the agent. Must be implemented by the subclass

Raises:

NotImplementedError – test method not implemented
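
Putting it together, a minimal concrete subclass only needs to implement train and test. The sketch below assumes the base constructor stores the environment on self.env:

from agents.BaseAgent import BaseAgent

class MyAgent(BaseAgent):
    def train(self) -> None:
        # run the training loop against self.env here
        ...

    def test(self) -> None:
        # run the evaluation loop against self.env here
        ...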

agents.DPG module

class agents.DPG.DPG(args: Namespace, env: BaseEnv, device: Optional[str] = None, test_mode: bool = False)

Bases: BaseAgent

The DPG class is a subclass of BaseAgent and implements the DPG algorithm.

Raises:

ValueError – missing model_load_path for testing

static add_args(parser: ArgumentParser) → None

add arguments to the parser

to add arguments to the parser, modify the method as follows:

@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        ...
    )

replacing the ellipsis with the desired argument definitions

Parameters:

parser (argparse.ArgumentParser) – the parser to add arguments to

__init__(args: Namespace, env: BaseEnv, device: Optional[str] = None, test_mode: bool = False) → None

the constructor for the DPG agent

Parameters:
  • args (argparse.Namespace) – arguments

  • env (BaseEnv) – the trading environment

  • device (Optional[str], optional) – torch device. Defaults to None, which means the device is automatically selected.

  • test_mode (bool, optional) – test or train mode. Defaults to False.

train() → None

train the DPG agent

_update_model() → float

update the model
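
A model update step of this kind typically computes a loss, backpropagates, and steps the optimizer. A generic sketch in which self.optimizer and the loss computation are assumptions, not the actual implementation:

def _update_model(self) -> float:
    # compute the training loss; _compute_loss is a hypothetical helper
    loss = self._compute_loss()
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    # return the scalar loss for logging, matching the float return type
    return loss.item()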

test() → None

test the agent. Must be implemented by the subclass

Raises:

NotImplementedError – test method not implemented

agents.DQN module

class agents.DQN.DQN(args: Namespace, env: BaseEnv, device: Optional[str] = None, test_mode: bool = False)

Bases: BaseAgent[BaseEnv]

The DQN class is a subclass of BaseAgent and implements the DQN algorithm.

Parameters:

BaseAgent (TypeVar) – the type of the environment

Raises:

ValueError – missing model_load_path for testing

static add_args(parser: ArgumentParser) → None

add arguments to the parser

to add arguments to the parser, modify the method as follows:

@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        ...
    )

replacing the ellipsis with the desired argument definitions

Parameters:

parser (argparse.ArgumentParser) – the parser to add arguments to

__init__(args: Namespace, env: BaseEnv, device: Optional[str] = None, test_mode: bool = False) → None

the constructor for the DQN agent

Parameters:
  • args (argparse.Namespace) – arguments

  • env (BaseEnv) – the trading environment

  • device (Optional[str], optional) – torch device. Defaults to None, which means the device is automatically selected.

  • test_mode (bool, optional) – test or train mode. Defaults to False.

train() → None

train the DQN agent

_update_epsilon() → None

update the epsilon
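
In DQN, epsilon is the exploration rate of the epsilon-greedy policy and is usually decayed toward a floor as training progresses. A common scheme (the attribute names are assumptions):

def _update_epsilon(self) -> None:
    # multiplicative decay with a lower bound; epsilon_decay and
    # epsilon_min are hypothetical hyperparameters
    self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)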

_update_Q_network() → float

randomly sample a batch of experiences from the replay buffer and update the Q network

Returns:

the training loss

Return type:

float
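
A standard DQN update samples a minibatch, builds the TD target from the target network, and regresses the Q network toward it. A PyTorch sketch in which self.replay_buffer, self.batch_size, self.gamma, and the network and optimizer attributes are all assumptions:

import torch
import torch.nn.functional as F

def _update_Q_network(self) -> float:
    # sample a minibatch of transitions (hypothetical buffer API)
    states, actions, rewards, next_states, dones = self.replay_buffer.sample(
        self.batch_size
    )
    # Q(s, a) for the actions actually taken
    q = self.Q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    # TD target r + gamma * max_a' Q_target(s', a'), zeroed at terminal states
    with torch.no_grad():
        next_q = self.target_network(next_states).max(dim=1).values
        target = rewards + self.gamma * next_q * (1.0 - dones)
    loss = F.mse_loss(q, target)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss.item()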

_update_target_network() → None

update the target network with the Q network weights
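
A hard update simply copies the weights across (attribute names assumed):

def _update_target_network(self) -> None:
    # overwrite the target network with the current Q network weights
    self.target_network.load_state_dict(self.Q_network.state_dict())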

test() → None

test the DQN agent

agents.MultiDQN module

class agents.MultiDQN.MultiDQN(args: Namespace, env: DiscreteRealDataEnv1, device: Optional[str] = None, test_mode: bool = False)

Bases: DQN[DiscreteRealDataEnv1]

The MultiDQN class is a subclass of DQN and implements the MultiDQN algorithm. It outputs the Q values of all possible actions simultaneously and uses DiscreteRealDataEnv1 as its environment.

References

original paper: https://arxiv.org/abs/1907.03665

reference implementation: https://github.com/Jogima-cyber/portfolio-manager

static add_args(parser: ArgumentParser) → None

add arguments to the parser

to add arguments to the parser, modify the method as follows:

@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        ...
    )

replacing the ellipsis with the desired argument definitions

Parameters:

parser (argparse.ArgumentParser) – the parser to add arguments to

__init__(args: Namespace, env: DiscreteRealDataEnv1, device: Optional[str] = None, test_mode: bool = False) → None

the constructor for the MultiDQN agent

Parameters:
  • args (argparse.Namespace) – arguments

  • env (DiscreteRealDataEnv1) – the trading environment

  • device (Optional[str], optional) – torch device. Defaults to None, which means the device is automatically selected.

  • test_mode (bool, optional) – test or train mode. Defaults to False.

train() → None

training the MultiDQN agent is composed of two steps: 1. pretrain the Q network; 2. train the Q network using the MultiDQN algorithm
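
In outline, the control flow is simply (both helpers are documented below):

def train(self) -> None:
    self._pretrain()        # step 1: pretrain the Q network
    self._multiDQN_train()  # step 2: train with the MultiDQN algorithm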

_pretrain() → None

the pretraining step of the multiDQN algorithm

_multiDQN_train() → None

the training step of the multiDQN algorithm

_update_Q_network() → float

randomly sample multiple experience lists from the replay buffer and update the Q network

Returns:

the training loss

Return type:

float

test() → None

test the MultiDQN agent